Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00

Compare commits: packed-qua...awq-tq (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | ae53ed9090 | |
| | d4ef909d4a | |
| | db109184b7 | |

llms/mlx_lm/awq.py (new file, 613 lines)

@@ -0,0 +1,613 @@
# Learned quantization using AWQ and TesseraQ:

# References:
# AWQ:
# https://arxiv.org/abs/2306.00978
# https://github.com/mit-han-lab/llm-awq
# TesseraQ:
# https://arxiv.org/abs/2410.19103
# https://github.com/Intelligent-Computing-Lab-Yale/TesseraQ
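
# Example invocation (a sketch based on the argparse flags defined in main() below,
# using the script's own default values; the mlx_lm.awq console script registered in
# setup.py points at the same entry point):
#
#   python -m mlx_lm.awq --model mlx-community/Qwen2.5-7B-Instruct-bf16 \
#       --bits 3 --group-size 64 --num-samples 32 --sequence-length 2048 \
#       --mlx-path mlx_model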

import argparse
import glob
import shutil
from pathlib import Path
from typing import Callable

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from datasets import Dataset, load_dataset
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_map_with_path
from mlx_lm.models.base import create_attention_mask
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights
from tqdm import tqdm
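
# Schedule of quantile thresholds used by round_block below: at each step, roughly
# the (1 - threshold) fraction of soft-rounding variables that are already closest
# to a hard 0/1 decision is frozen, and the remaining variables keep being optimized.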
ROUNDING_THRESHOLDS = [
    0.8,
    0.65,
    0.5,
    0.43,
    0.38,
    0.34,
    0.3,
    0.27,
    0.24,
    0.21,
    0.18,
    0.15,
    0.12,
    0.10,
    0.08,
    0.06,
    0.04,
    0.02,
    0.01,
    0.005,
]


def mse(x, y):
    return ((x - y).astype(mx.float32)) ** 2
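

# Stretched (rectified) sigmoid pair used to parameterize the learnable rounding
# variables in RoundQuant below: sigmoid() maps an unconstrained variable into
# [0, 1] and can saturate exactly at 0 and 1 thanks to the gamma/zeta stretch,
# while sigmoid_inverse() initializes the variable from an observed rounding fraction.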
def sigmoid(x: mx.array, gamma: float = -0.1, zeta: float = 1.1):
    return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)


def sigmoid_inverse(y: mx.array, gamma: float = -0.1, zeta: float = 1.1):
    return -mx.log((zeta - gamma) / (y - gamma) - 1)


def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs):
    y = []
    for i in range(0, x.shape[0], batch_size):
        y.append(layer(x[i : i + batch_size], **kwargs))
        mx.eval(y)
    y = mx.concatenate(y, axis=0)
    return y


def dist_split(x: mx.array, group: mx.distributed.Group):
    B = x.shape[0]
    N = group.size()
    assert B % N == 0
    r = group.rank()
    local_B = (B + N - 1) // N
    return x[r * local_B : (r + 1) * local_B]
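

# AWQ scale search: try per-channel scales derived from the mean activation
# magnitude (x_max ** ratio over a small grid of ratios), fake-quantize the scaled
# weights, and keep the scales that minimize the block's output MSE.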
def search_best_scale(
    layers: list[nn.Module],
    x: mx.array,
    quantize_func: Callable,
    block: nn.Module | None = None,
    layer_kwargs: dict | None = None,
    n_grid: int = 1,
):
    group = mx.distributed.init() if mx.distributed.is_available() else None
    layer_kwargs = layer_kwargs or {}

    block = block or layers[0]
    out = block(x, **layer_kwargs)

    x_max = x.abs().mean(axis=(0, 1))

    best_error = float("inf")
    best_scales = None

    weights = tree_flatten(block.parameters())

    for ratio in tqdm(range(n_grid)):
        ratio = ratio * 1 / n_grid
        scales = mx.maximum(x_max**ratio, 1e-4).reshape(-1)
        scales = scales / (scales.max() * scales.min()).sqrt()
        for layer in layers:
            layer.weight = quantize_func(layer.weight * scales) / scales

        out_q = run_layer(block, x, **layer_kwargs)
        loss = mse(out, out_q).sum()
        if group is not None:
            loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
        loss /= out.size
        mx.eval(loss)
        is_best = loss < best_error
        if is_best:
            best_error = loss
            best_scales = scales

        # reload the original weights
        block.load_weights(weights)

    best_scales = best_scales.reshape(-1)
    mx.eval(best_scales)
    return best_scales


def apply_scale(prev_op, layers, scales):
    # Apply the scales to the layers
    if isinstance(prev_op, nn.Linear):
        assert len(layers) == 1
        prev_op.weight = prev_op.weight / scales[:, mx.newaxis]
        if hasattr(prev_op, "bias"):
            prev_op.bias = prev_op.bias / scales
        layers[0].weight = layers[0].weight * scales[mx.newaxis]
    elif isinstance(prev_op, (nn.LayerNorm, nn.RMSNorm)):
        prev_op.weight = prev_op.weight / scales
        if hasattr(prev_op, "bias"):
            prev_op.bias = prev_op.bias / scales
        for layer in layers:
            layer.weight = layer.weight * scales
    else:
        raise NotImplementedError(f"Could not apply scale to prev_op: {prev_op}")
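

# Apply the scale search to one transformer block. The attribute names used here
# (self_attn.q_proj, mlp.gate_proj, input_layernorm, and so on) assume a
# Llama-style decoder layer, with a Gemma-style pre_feedforward_layernorm fallback
# for the MLP norm.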
def scale_block(
    block, input_feat, quantize_func: Callable, layer_kwargs: dict | None = None
):
    layers = [
        block.self_attn.q_proj,
        block.self_attn.k_proj,
        block.self_attn.v_proj,
    ]
    scales = search_best_scale(
        layers=layers,
        block=block.self_attn,
        x=input_feat["q_proj"],
        quantize_func=quantize_func,
        layer_kwargs=layer_kwargs,
    )
    apply_scale(block.input_layernorm, layers, scales)
    for name in ["q_proj", "k_proj", "v_proj"]:
        input_feat[name] = input_feat[name] / scales

    layers = [
        block.mlp.gate_proj,
        block.mlp.up_proj,
    ]
    scales = search_best_scale(
        block=block.mlp,
        layers=layers,
        x=input_feat["gate_proj"],
        quantize_func=quantize_func,
    )
    mlp_norm = getattr(
        block, "pre_feedforward_layernorm", block.post_attention_layernorm
    )
    apply_scale(mlp_norm, layers, scales)
    for name in ["gate_proj", "up_proj"]:
        input_feat[name] = input_feat[name] / scales

    layers = [block.mlp.down_proj]
    scales = search_best_scale(
        layers=layers,
        x=input_feat["down_proj"],
        quantize_func=quantize_func,
    )
    apply_scale(block.mlp.up_proj, layers, scales)
    input_feat["down_proj"] = input_feat["down_proj"] / scales
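

# AWQ clip search: for each weight group, shrink the min/max clipping range over a
# small grid and keep the range whose fake-quantized weights best reproduce the
# original per-group outputs (measured with a grouped einsum against subsampled
# activations).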
def search_best_clip(
    w: mx.array,
    x: mx.array,
    quantize_func: Callable,
    group_size: int,
    n_grid: int = 2,
    max_shrink: float = 0.5,
    subsample: int = 4,
    batch_size: int = 64,
):
    group = mx.distributed.init() if mx.distributed.is_available() else None

    x = x[:, ::subsample]
    x = x.reshape(*x.shape[:-1], -1, group_size)

    w_all = w
    w_max_all = []
    w_min_all = []

    for b in range(0, w.shape[0], batch_size):
        w = w_all[b : b + batch_size]

        group_shape = (w.shape[0], w.shape[-1] // group_size)
        best_error = mx.full(group_shape, float("inf"))
        best_w_max = mx.zeros((*group_shape, 1), dtype=x.dtype)
        best_w_min = mx.zeros((*group_shape, 1), dtype=x.dtype)

        w_shape = w.shape

        w = w.reshape(*w.shape[:-1], -1, group_size)
        out = mx.einsum("btdg,odg->btod", x, w)

        for i in range(int(max_shrink * n_grid)):
            p = 1 - i / n_grid
            w_max = p * w.max(axis=-1, keepdims=True)
            w_min = p * w.min(axis=-1, keepdims=True)
            w_m = mx.clip(w, w_min, w_max).reshape(w_shape)

            w_q = quantize_func(w_m)

            w_q = w_q.reshape(*w_q.shape[:-1], -1, group_size)
            out_q = mx.einsum("btdg,odg->btod", x, w_q)

            # Take the mean across the input batch
            loss = mse(out, out_q).sum(axis=(0, 1))
            if group is not None:
                loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
            loss /= out.shape[0] * out.shape[1]
            best_indices = loss < best_error
            best_error = mx.where(best_indices, loss, best_error)
            best_w_max = mx.where(best_indices[..., mx.newaxis], w_max, best_w_max)
            best_w_min = mx.where(best_indices[..., mx.newaxis], w_min, best_w_min)
            mx.eval(best_w_max, best_w_min, best_error)

        w_max_all.append(best_w_max)
        w_min_all.append(best_w_min)

    best_w_max = mx.concatenate(w_max_all, axis=0)
    best_w_min = mx.concatenate(w_min_all, axis=0)

    w_r = w_all.reshape(*w_all.shape[:-1], -1, group_size)
    best_w = mx.clip(w_r, best_w_min, best_w_max)
    best_w = best_w.reshape(w_all.shape)

    mx.eval(best_w)
    return best_w


def clip_block(
    block: nn.Module,
    input_feat: dict[str, mx.array],
    quantize_func: Callable,
    group_size: int,
):

    def apply_clip(path, module):
        if (
            isinstance(module, nn.Linear)
            and "q_proj" not in path
            and "k_proj" not in path
        ):
            name = path.split(".")[-1]
            best_weight = search_best_clip(
                module.weight,
                input_feat[name],
                quantize_func=quantize_func,
                group_size=group_size,
            )
            module.weight = best_weight

    tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module)
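

# TesseraQ-style learnable rounding. Each nn.Linear is temporarily replaced by a
# RoundQuant module that holds the quantization grid (scales/biases from
# mx.quantize) plus two trainable tensors: `rounding`, a per-weight soft rounding
# variable, and `v`, a per-group rescaling of the scales.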
class RoundQuant(nn.Module):
    def __init__(self, module, group_size: int = 64, bits: int = 3):
        super().__init__()
        self.bits = bits
        self.group_size = group_size
        self._weight = module.weight
        if hasattr(module, "bias"):
            self._bias = module.bias

        _, self._scales, self._biases = mx.quantize(
            self._weight, group_size=group_size, bits=bits
        )
        self._scales = self._scales[..., mx.newaxis]
        self._biases = self._biases[..., mx.newaxis]

        self._weight = self._weight.reshape(self._weight.shape[0], -1, group_size)
        rounding = self._weight / self._scales
        rounding = rounding - mx.floor(rounding)
        self.rounding = sigmoid_inverse(rounding)
        self.v = mx.zeros_like(self._scales)

    def __call__(self, x: mx.array):
        q = (self._weight - self._biases) / self._scales
        q = mx.floor(q) + sigmoid(self.rounding)
        q = mx.clip(q, 0, 2**self.bits - 1)
        w = (q * self._scales * 2 * sigmoid(self.v)) + self._biases
        w = w.reshape(w.shape[0], -1)

        if hasattr(self, "_bias"):
            x = mx.addmm(self._bias, x, w.T)
        else:
            x = x @ w.T
        return x

    def to_quantized(self, group_size: int = 64, bits: int = 3):
        assert (
            group_size == self.group_size and bits == self.bits
        ), "Quantization parameters must match"
        w = self._weight
        output_dims, input_dims = w.shape[0], w.shape[1] * w.shape[2]
        use_bias = hasattr(self, "_bias")

        q = (w - self._biases) / self._scales
        q = mx.floor(q) + sigmoid(self.rounding)
        q = mx.clip(q, 0, 2**bits - 1)
        q = q.astype(mx.uint32)

        w = q * self._scales * 2 * sigmoid(self.v) + self._biases
        w = w.reshape(w.shape[0], -1)

        q = q.reshape(q.shape[0], -1)
        bitarr = (q[..., mx.newaxis] >> mx.arange(bits, dtype=mx.uint32)) & 1
        w_q = bitarr.reshape((q.shape[0], -1, 32))
        w_q = (w_q << mx.arange(32, dtype=mx.uint32)).sum(axis=-1)

        qlayer = nn.QuantizedLinear(input_dims, output_dims, use_bias, group_size, bits)
        new_scales = self._scales * 2 * sigmoid(self.v)
        qlayer.weight = w_q
        qlayer.scales = new_scales[..., 0]
        qlayer.biases = self._biases[..., 0]
        if use_bias:
            qlayer.bias = self._bias

        return qlayer
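

# Block-wise rounding optimization: wrap every linear layer in RoundQuant, then
# alternate between hardening the most confident rounding variables (following the
# ROUNDING_THRESHOLDS schedule) and running Adam on the remaining soft variables to
# minimize the block's output MSE against the full-precision outputs.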
def round_block(
    block: nn.Module,
    inputs: mx.array,
    outputs: mx.array,
    group_size: int = 64,
    bits: int = 3,
    layer_kwargs: dict | None = None,
    batch_size: int = 4,
):
    layer_kwargs = layer_kwargs or {}

    block.freeze()
    leaves = block.leaf_modules()
    rounded = tree_map(
        lambda m: RoundQuant(m, group_size, bits) if isinstance(m, nn.Linear) else m,
        leaves,
        is_leaf=nn.Module.is_module,
    )
    block.update_modules(rounded)

    def hard_round(module, threshold: float = 0):
        if not isinstance(module, RoundQuant):
            return module
        score = mx.abs(sigmoid(module.rounding) - 0.5)
        value = mx.array(np.quantile(score.astype(mx.float32), q=threshold))
        rounding = mx.where(
            sigmoid(module.rounding) > value + 0.5, float("inf"), module.rounding
        )
        module.rounding = mx.where(
            sigmoid(module.rounding) <= 0.5 - value, -float("inf"), rounding
        )
        return module

    for threshold in ROUNDING_THRESHOLDS:
        print("threshold", threshold)
        optimizer = optim.Adam(learning_rate=1e-3)

        tree_map(
            lambda m: hard_round(m, threshold),
            block.leaf_modules(),
            is_leaf=nn.Module.is_module,
        )

        def loss(block, inputs, outputs):
            outputs_q = run_layer(block, inputs, **layer_kwargs)
            return mse(outputs, outputs_q).mean()

        loss_value_and_grad = nn.value_and_grad(block, loss)

        for i in range(0, inputs.shape[0], batch_size):
            lvalue, grad = loss_value_and_grad(
                block, inputs[i : i + batch_size], outputs[i : i + batch_size]
            )
            if mx.distributed.is_available():
                grad = average_gradients(grad)
            optimizer.update(block, grad)
            mx.eval(block.parameters(), optimizer.state, lvalue)
            print(lvalue)

    tree_map(hard_round, block.leaf_modules(), is_leaf=nn.Module.is_module)
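

# End-to-end pipeline: quantize the embeddings, then process each transformer block
# in turn. For every block the script captures per-projection inputs, reports the
# loss of plain round-to-nearest quantization, applies AWQ scaling and clipping,
# runs the TesseraQ rounding optimization, reports the post-optimization loss, and
# passes the block's full-precision outputs on as the next block's calibration inputs.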
def awq_quantize(
    model,
    inputs: mx.array,
    group_size: int = 64,
    bits: int = 3,
    embed_group_size: int = 32,
    embed_bits: int = 4,
):
    group = mx.distributed.init() if mx.distributed.is_available() else None

    def quantize_func(w):
        wq = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(*wq, bits=bits, group_size=group_size)

    mask = create_attention_mask(inputs)

    model.model.embed_tokens = model.model.embed_tokens.to_quantized(
        group_size=embed_group_size, bits=embed_bits
    )
    inputs = model.model.embed_tokens(inputs)

    input_feat = {}

    def capture(path, module):
        if not isinstance(module, nn.Linear):
            return module

        class Catcher(nn.Module):
            def __call__(self, x: mx.array):
                name = path.split(".")[-1]
                input_feat[name] = x
                return module(x)

        return Catcher()

    for i, layer in enumerate(model.model.layers):
        import time

        s = time.time()
        print(f"Starting block {i}")

        # capture the inputs to each layer
        orig_leaves = layer.leaf_modules()
        capture_leaves = tree_map_with_path(
            capture, orig_leaves, is_leaf=nn.Module.is_module
        )
        layer.update_modules(capture_leaves)
        outputs = run_layer(layer, inputs, mask=mask)
        layer.update_modules(orig_leaves)
        del capture_leaves

        nn.quantize(layer, group_size=group_size, bits=bits)
        outputs_q = run_layer(layer, inputs, mask=mask)
        loss = mse(outputs, outputs_q).sum()
        if group is not None:
            loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
        loss /= outputs.size
        print("Before Loss", loss, flush=True)
        layer.update_modules(orig_leaves)
        del orig_leaves

        print("Scaling block", flush=True)
        scale_block(
            block=layer,
            input_feat=input_feat,
            quantize_func=quantize_func,
            layer_kwargs={"mask": mask},
        )

        print("Clipping block", flush=True)
        clip_block(
            block=layer,
            input_feat=input_feat,
            quantize_func=quantize_func,
            group_size=group_size,
        )

        print("Rounding block")
        round_block(
            block=layer,
            inputs=inputs,
            outputs=outputs,
            group_size=group_size,
            bits=bits,
            layer_kwargs={"mask": mask},
        )

        nn.quantize(layer, group_size=group_size, bits=bits)
        outputs_q = run_layer(layer, inputs, mask=mask)
        loss = mse(outputs, outputs_q).sum()
        if group is not None:
            loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
        loss /= outputs.size
        print("After Loss", loss, flush=True)

        input_feat = {}
        inputs = outputs

        mx.eval(layer)
        mx.metal.clear_cache()

        e = time.time()
        print("Loop time: ", e - s)

    if hasattr(model, "lm_head"):
        model.lm_head = model.lm_head.to_quantized(
            group_size=embed_group_size, bits=embed_bits
        )
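

# Calibration data: random fixed-length chunks of wikitext-2, prefixed with the
# BOS token when the tokenizer defines one.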
def load_wikitext(
    tokenizer, num_samples: int = 32, sequence_length: int = 2048, split: str = "train"
) -> mx.array:
    dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
    texts = "\n\n".join(dataset["text"])
    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]

    # Select random chunks
    starts = mx.random.randint(
        0, len(tokens) - sequence_length - 1, shape=(num_samples, 1)
    )
    data = tokens[starts + mx.arange(sequence_length)]
    if tokenizer.bos_token_id:
        data = mx.concatenate(
            [mx.full((*data.shape[:2], 1), tokenizer.bos_token_id), data], axis=-1
        )
    return data


def save_model(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    config,
    model_path: Path,
    mlx_path: str,
):
    weights = dict(tree_flatten(model.parameters()))

    mlx_path = Path(mlx_path)
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    config["quantization"] = {"group_size": 64, "bits": 4}

    def update_config(path, module):
        if hasattr(module, "bits"):
            config["quantization"][path] = {
                "group_size": module.group_size,
                "bits": module.bits,
            }
        else:
            config["quantization"][path] = False

    tree_map_with_path(update_config, model.leaf_modules(), is_leaf=nn.Module.is_module)

    save_config(config, config_path=mlx_path / "config.json")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", "-m", default="mlx-community/Qwen2.5-7B-Instruct-bf16"
    )
    parser.add_argument("--mlx-path", default="mlx_model")
    parser.add_argument("--bits", type=int, default=3)
    parser.add_argument("--group-size", type=int, default=64)
    parser.add_argument("--num-samples", type=int, default=32)
    parser.add_argument("--sequence-length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()

    group = mx.distributed.init() if mx.distributed.is_available() else None

    num_samples = args.num_samples
    if group is not None and num_samples % group.size() > 0:
        num_samples += group.size() - num_samples % group.size()

    mx.random.seed(args.seed)

    model_path = get_model_path(args.model, revision=None)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    calibration_data = load_wikitext(tokenizer, args.num_samples, args.sequence_length)

    if group is not None:
        calibration_data = dist_split(calibration_data, group)

    awq_quantize(model, calibration_data, bits=args.bits, group_size=args.group_size)

    save_model(model, tokenizer, config, model_path, args.mlx_path)


if __name__ == "__main__":
    main()

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import codecs
import json
import sys

@@ -189,8 +188,8 @@ def main():
    elif using_cache:
        tokenizer.chat_template = metadata["chat_template"]

    prompt = codecs.decode(args.prompt, "unicode_escape")

    prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
    prompt = sys.stdin.read() if prompt == "-" else prompt
    if not args.ignore_chat_template and (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
@@ -199,12 +198,7 @@ def main():
            messages = [{"role": "system", "content": args.system_prompt}]
        else:
            messages = []
        messages.append(
            {
                "role": "user",
                "content": sys.stdin.read() if prompt == "-" else prompt,
            }
        )
        messages.append({"role": "user", "content": prompt})
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

@@ -23,7 +23,12 @@ class BaseModelArgs:
        )


def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
def create_causal_mask(
    N: int,
    offset: int = 0,
    window_size: Optional[int] = None,
    lengths: Optional[mx.array] = None,
):
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
    mask = linds < rinds
    if window_size is not None:
        mask = mask | (linds > rinds + window_size)
    if lengths is not None:
        lengths = lengths[:, None, None, None]
        mask = mask | (rinds >= lengths)
    return mask * -1e9
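
The new lengths argument masks out key positions at or beyond each sequence's true
length, so padded batches only attend to valid tokens. A minimal sketch of calling it
(mirroring the test added further below):

    lengths = mx.array([1, 2, 3, 1])
    mask = create_causal_mask(3, 0, lengths=lengths)  # additive mask, -1e9 at masked positions
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)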

@@ -155,11 +155,13 @@ class CohereModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -180,9 +182,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.embed_tokens.as_linear(out)
        out = out * self.model.args.logit_scale
        return out

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@@ -151,16 +151,13 @@ class CohereModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        T = h.shape[1]
        if T > 1:
            offset = cache[0].offset if cache else 0
            mask = create_causal_mask(T, offset).astype(h.dtype)
        else:
            mask = None
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -181,9 +178,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.embed_tokens.as_linear(out)
        out = out * self.model.args.logit_scale
        return out

@@ -197,11 +197,13 @@ class DBRX(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.wte(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.blocks)

@@ -223,9 +225,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        out = self.transformer(inputs, mask, cache)
        return self.lm_head(out)

    @property

@@ -211,9 +211,11 @@ class DeepseekModel(nn.Module):
        self,
        x: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)
        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -236,8 +238,9 @@ class Model(nn.Module):
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, cache, mask)
        return self.lm_head(out)

    def sanitize(self, weights):

@@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module):
        self,
        x: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)
        mask = create_attention_mask(h, cache)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -395,8 +398,9 @@ class Model(nn.Module):
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, cache, mask)
        return self.lm_head(out)

    def sanitize(self, weights):

@@ -123,10 +123,12 @@ class ExaoneModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.wte(inputs)
        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.h)

@@ -149,9 +151,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        out = self.transformer(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.transformer.wte.as_linear(out)
        else:

@@ -138,12 +138,14 @@ class GemmaModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        h = h * (self.args.hidden_size**0.5)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -164,9 +166,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.embed_tokens.as_linear(out)
        return out

@@ -160,12 +160,14 @@ class GemmaModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        h = h * (self.args.hidden_size**0.5)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -187,9 +189,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.embed_tokens.as_linear(out)
        out = mx.tanh(out / self.final_logit_softcapping)
        out = out * self.final_logit_softcapping

@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        _, L = inputs.shape

@@ -138,7 +139,8 @@ class GPT2Model(nn.Module):
            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

        mask = create_attention_mask(hidden_states, cache)
        if mask is None:
            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

@@ -159,9 +161,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.wte.as_linear(out)
        return out


@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        B, L = inputs.shape

@@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

        mask = create_attention_mask(hidden_states, cache)
        if mask is None:
            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

@@ -172,9 +174,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        out = self.transformer(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.transformer.wte.as_linear(out)
        else:

@@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.embed_in(inputs)

        mask = create_attention_mask(hidden_states, cache)
        if mask is None:
            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

@@ -176,9 +178,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return out

    def sanitize(self, weights):

@@ -239,11 +239,13 @@ class HunYuanModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -266,9 +268,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return self.model.embed_tokens.as_linear(out)

    def sanitize(self, weights):

@@ -193,11 +193,13 @@ class InternLM2Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -220,9 +222,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.tok_embeddings.as_linear(out)
        else:

@@ -155,11 +155,13 @@ class LlamaModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -182,9 +184,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:

@@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs) * self.args.scale_emb

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -186,9 +188,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)

        if not self.args.tie_word_embeddings:
            out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))

@@ -162,11 +162,13 @@ class MixtralModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -188,9 +190,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return self.lm_head(out)

    def sanitize(self, weights):

@@ -176,11 +176,13 @@ class NemotronModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -203,9 +205,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:

@@ -124,11 +124,13 @@ class Transformer(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.wte(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.blocks)

@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        return self.transformer(inputs, cache)
        return self.transformer(inputs, mask, cache)


class Model(nn.Module):

@@ -167,9 +170,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        return self.model(inputs, cache)
        return self.model(inputs, mask, cache)

    @property
    def layers(self):

@@ -163,10 +163,12 @@ class LlamaModel(nn.Module):
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -190,8 +192,9 @@ class Model(nn.Module):
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, cache, mask)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:

@@ -178,11 +178,13 @@ class OpenELMModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.token_embeddings(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -205,9 +207,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        out = self.transformer(inputs, mask, cache)
        if self.args.share_input_output_layers:
            out = self.transformer.token_embeddings.as_linear(out)
        else:

@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
            config.hidden_size, eps=config.layer_norm_eps
        )

    def __call__(self, x, cache):
    def __call__(self, x, mask, cache):
        x = self.embed_tokens(x)

        mask = create_attention_mask(x, cache)
        if mask is None:
            mask = create_attention_mask(x, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -167,9 +168,10 @@ class Model(nn.Module):
    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        y = self.model(x, cache)
        y = self.model(x, mask, cache)
        return self.lm_head(y)

    @property

@@ -168,11 +168,13 @@ class Phi3Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -194,9 +196,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return self.lm_head(out)

    @property

@@ -258,13 +258,15 @@ class Phi3Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        if self.mup_embedding_multiplier:
            h = self.mup_embedding_multiplier * h

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -290,9 +292,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        out = self.model.embed_tokens.as_linear(out)
        if self.mup_width_multiplier:
            out = out / self.mup_width_multiplier

@@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -181,9 +183,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return self.lm_head(out)

    def sanitize(self, weights):

@@ -175,7 +175,9 @@ class Model(nn.Module):
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        mask = create_attention_mask(x, cache)

        if mask is None:
            mask = create_attention_mask(x, cache)

        y = self.transformer(x, mask, cache)
        return self.lm_head(y)

@@ -174,10 +174,12 @@ class PlamoModel(nn.Module):
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None for _ in range(len(self.layers.layers))]

@@ -202,8 +204,9 @@ class Model(nn.Module):
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        out = self.model(inputs, cache)
        out = self.model(inputs, cache, mask)
        return self.lm_head(out)

    @property

@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
    def __call__(self, inputs, mask=None, cache=None):
        x = self.wte(inputs)

        mask = create_attention_mask(x, cache)
        if mask is None:
            mask = create_attention_mask(x, cache)

        if cache is None:
            cache = [None] * len(self.h)

@@ -149,11 +149,13 @@ class Qwen2Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -176,9 +178,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:

@@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -213,9 +215,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        return self.lm_head(out)

    def sanitize(self, weights):

@@ -389,6 +389,7 @@ class Griffin(nn.Module):
    def __call__(
        self,
        tokens,
        mask: mx.array = None,
        cache=None,
    ):
        x = self.embed_tokens(tokens)

@@ -402,7 +403,8 @@ class Griffin(nn.Module):
            if block.temporal_block_type != "recurrent":
                mask_cache = [cache[i]]

        mask = create_attention_mask(x, mask_cache)
        if mask is None:
            mask = create_attention_mask(x, mask_cache)

        for i, block in enumerate(self.layers):
            x = block(x, mask=mask, cache=cache[i])

@@ -418,12 +420,12 @@ class Model(nn.Module):
        self.model_type = config.model_type
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
    def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
        """
        Args:
            tokens: Sequence of input tokens.
        """
        logits = self.model(tokens, cache=cache)
        logits = self.model(tokens, mask=mask, cache=cache)
        if "lm_head" in self:
            logits = self.lm_head(logits)
        else:

@@ -199,7 +199,10 @@ class Model(nn.Module):
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        mask = create_attention_mask(x, cache)

        if mask is None:
            mask = create_attention_mask(x, cache)

        y = self.model(x, mask, cache)
        return self.lm_head(y)


@@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

@@ -152,9 +154,10 @@ class Model(nn.Module):
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:

@@ -12,6 +12,7 @@ def make_sampler(
    top_p: float = 0.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    top_k: int = -1,
) -> Callable[mx.array, mx.array]:
    """
    Make a sampler function for use with ``generate_step``.

@@ -25,6 +26,8 @@ def make_sampler(
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
        top_k (int, optional): The top k tokens ranked by probability to constrain
            the sampling to.

    Returns:
        Callable[mx.array, mx.array]:

@@ -36,6 +39,8 @@ def make_sampler(
        return lambda x: top_p_sampling(x, top_p, temp)
    elif min_p != 0.0:
        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
    elif top_k > 0:
        return lambda x: top_k_sampling(x, top_k, temp)
    else:
        return lambda x: categorical_sampling(x, temp)

@@ -79,6 +84,33 @@ def make_logits_processors(
    return logits_processors


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
    logprobs: mx.array,
    top_k: int,
    temperature=1.0,
) -> mx.array:
    """
    Sample from only the top K tokens ranked by probability.

    Args:
        logprobs: A vector of log probabilities.
        top_k (int): Top k tokens to sample from.
    """
    vocab_size = logprobs.shape[-1]
    if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
        raise ValueError(
            f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
            f" but is {top_k}."
        )
    logprobs = logprobs * (1 / temperature)
    mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
    masked_logprobs = mx.put_along_axis(
        logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
    )
    return mx.random.categorical(masked_logprobs, axis=-1)
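
A quick usage sketch for the new sampler (the same behavior is exercised by the test
added at the bottom of this diff):

    logits = mx.log(mx.array([0.9, 0.0, 0.0, 0.1])[None])
    token = top_k_sampling(logits, 1).item()  # only the most likely token can be drawn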


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
    logprobs: mx.array,

@@ -87,7 +119,7 @@ def min_p_sampling(
    temperature=1.0,
) -> mx.array:
    """
    Apply min-p sampling to the logits.
    Apply min-p sampling to the logprobs.

    Min-p keeps all tokens that are above a minimum probability, scaled by the
    probability of the most likely token. As a result, the filter is more

@@ -32,6 +32,7 @@ setup(
    },
    entry_points={
        "console_scripts": [
            "mlx_lm.awq = mlx_lm.awq:main",
            "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
            "mlx_lm.chat = mlx_lm.chat:main",
            "mlx_lm.convert = mlx_lm.convert:main",

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache


@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
        self.assertEqual(cache.offset, 22)
        self.assertTrue(mx.allclose(x, k[..., -2:, :]))

    def test_causal_mask_lengths(self):
        mx.random.seed(8)
        B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
        lengths = mx.array([1, 2, 3, 1])
        q = mx.random.uniform(shape=(B, N_q, T_q, D))
        k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
        v = k
        mask = create_causal_mask(T_q, 0, lengths=lengths)

        out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
        k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
        v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
        out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))

    def test_rope(self):
        rope = rope_utils.initialize_rope(32, base=100, traditional=False)
        self.assertTrue(isinstance(rope, nn.RoPE))

@@ -162,10 +179,16 @@ class TestModels(unittest.TestCase):
        self.assertEqual(outputs.dtype, t)

        cache = make_prompt_cache(model)
        outputs = model(inputs, cache)
        outputs = model(inputs, cache=cache)
        self.assertEqual(outputs.shape, (1, 2, vocab_size))
        self.assertEqual(outputs.dtype, t)

        if model_type != "mamba":
            mask = create_causal_mask(inputs.shape[1], 0).astype(t)
            outputs = model(inputs, mask=mask)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

        outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
        self.assertEqual(outputs.shape, (1, 1, vocab_size))
        self.assertEqual(outputs.dtype, t)

@@ -1,7 +1,7 @@
import unittest

import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling


class TestSampleUtils(unittest.TestCase):

@@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase):
        token = min_p_sampling(logits, 0.05)
        self.assertTrue(token in (0, 3))

    def test_top_k_sampling(self):
        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
        logits = mx.log(probs)

        token = top_k_sampling(logits, 1).item()
        self.assertEqual(token, 0)

        probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
        tokens = set()
        for _ in range(100):
            token = top_k_sampling(logits, 2)
            tokens.add(token.item())
        self.assertEqual(tokens, {0, 3})

        # Batch mode works
        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
        logits = mx.log(probs)

        tokens = top_k_sampling(logits, 1)
        self.assertEqual(tokens.tolist(), [0, 1])


if __name__ == "__main__":
    unittest.main()