1 Commit
awq ... awq-tq

Author        SHA1        Message                Date
Alex Barron   ae53ed9090  add TesseraQ rounding  2024-12-19 19:35:26 -08:00


@@ -1,9 +1,12 @@
-# Learned quantization using AWQ:
+# Learned quantization using AWQ and TesseraQ:
 # References:
-# AWQ
+# AWQ:
 # https://arxiv.org/abs/2306.00978
 # https://github.com/mit-han-lab/llm-awq
+# TesseraQ:
+# https://arxiv.org/abs/2410.19103
+# https://github.com/Intelligent-Computing-Lab-Yale/TesseraQ

 import argparse
 import glob
@@ -13,19 +16,52 @@ from typing import Callable
 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as optim
 import numpy as np
 from datasets import Dataset, load_dataset
-from mlx.utils import tree_flatten, tree_map_with_path
+from mlx.nn.utils import average_gradients
+from mlx.utils import tree_flatten, tree_map, tree_map_with_path
 from mlx_lm.models.base import create_attention_mask
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights
 from tqdm import tqdm

+ROUNDING_THRESHOLDS = [
+    0.8,
+    0.65,
+    0.5,
+    0.43,
+    0.38,
+    0.34,
+    0.3,
+    0.27,
+    0.24,
+    0.21,
+    0.18,
+    0.15,
+    0.12,
+    0.10,
+    0.08,
+    0.06,
+    0.04,
+    0.02,
+    0.01,
+    0.005,
+]
+

 def mse(x, y):
     return ((x - y).astype(mx.float32)) ** 2


+def sigmoid(x: mx.array, gamma: float = -0.1, zeta: float = 1.1):
+    return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)
+
+
+def sigmoid_inverse(y: mx.array, gamma: float = -0.1, zeta: float = 1.1):
+    return -mx.log((zeta - gamma) / (y - gamma) - 1)
+
+
 def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs):
     y = []
     for i in range(0, x.shape[0], batch_size):
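
Note: sigmoid / sigmoid_inverse above implement the stretched (rectified) sigmoid used for learned rounding, h(V) = clip(sigmoid(V) * (zeta - gamma) + gamma, 0, 1) with gamma = -0.1 and zeta = 1.1. sigmoid_inverse is its inverse for values strictly inside (0, 1), which is how RoundQuant further down initializes its soft-rounding variable from each weight's fractional part; pushing the variable to +/-inf later pins the rounding hard to 1 or 0. A minimal standalone sketch of that behavior (helper names here are illustrative, not part of the commit; assumes mlx is installed):

# Sketch only: the stretched sigmoid maps an unbounded rounding variable to a
# soft rounding value in [0, 1]; the inverse undoes the unclipped mapping.
import mlx.core as mx
import mlx.nn as nn

GAMMA, ZETA = -0.1, 1.1

def stretched_sigmoid(v):
    return mx.clip(nn.sigmoid(v) * (ZETA - GAMMA) + GAMMA, 0, 1)

def stretched_sigmoid_inverse(y):
    # valid for y strictly inside (0, 1)
    return -mx.log((ZETA - GAMMA) / (y - GAMMA) - 1)

frac = mx.array([0.1, 0.4, 0.6, 0.9])   # fractional parts of weight / scale
v = stretched_sigmoid_inverse(frac)      # initialize V so the soft round equals frac
print(stretched_sigmoid(v))              # ~[0.1, 0.4, 0.6, 0.9]
print(stretched_sigmoid(mx.array([float("inf"), -float("inf")])))  # [1, 0]: hard round up / down
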
@@ -50,7 +86,7 @@ def search_best_scale(
     quantize_func: Callable,
     block: nn.Module | None = None,
     layer_kwargs: dict | None = None,
-    n_grid: int = 20,
+    n_grid: int = 1,
 ):
     group = mx.distributed.init() if mx.distributed.is_available() else None
     layer_kwargs = layer_kwargs or {}
@@ -160,7 +196,7 @@ def search_best_clip(
     x: mx.array,
     quantize_func: Callable,
     group_size: int,
-    n_grid: int = 20,
+    n_grid: int = 2,
     max_shrink: float = 0.5,
     subsample: int = 4,
     batch_size: int = 64,
@@ -248,6 +284,134 @@ def clip_block(
     tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module)


+class RoundQuant(nn.Module):
+    def __init__(self, module, group_size: int = 64, bits: int = 3):
+        super().__init__()
+        self.bits = bits
+        self.group_size = group_size
+        self._weight = module.weight
+        if hasattr(module, "bias"):
+            self._bias = module.bias
+        _, self._scales, self._biases = mx.quantize(
+            self._weight, group_size=group_size, bits=bits
+        )
+        self._scales = self._scales[..., mx.newaxis]
+        self._biases = self._biases[..., mx.newaxis]
+        self._weight = self._weight.reshape(self._weight.shape[0], -1, group_size)
+        rounding = self._weight / self._scales
+        rounding = rounding - mx.floor(rounding)
+        self.rounding = sigmoid_inverse(rounding)
+        self.v = mx.zeros_like(self._scales)
+
+    def __call__(self, x: mx.array):
+        q = (self._weight - self._biases) / self._scales
+        q = mx.floor(q) + sigmoid(self.rounding)
+        q = mx.clip(q, 0, 2**self.bits - 1)
+        w = (q * self._scales * 2 * sigmoid(self.v)) + self._biases
+        w = w.reshape(w.shape[0], -1)
+        if hasattr(self, "_bias"):
+            x = mx.addmm(self._bias, x, w.T)
+        else:
+            x = x @ w.T
+        return x
+
+    def to_quantized(self, group_size: int = 64, bits: int = 3):
+        assert (
+            group_size == self.group_size and bits == self.bits
+        ), "Quantization parameters must match"
+        w = self._weight
+        output_dims, input_dims = w.shape[0], w.shape[1] * w.shape[2]
+        use_bias = hasattr(self, "_bias")
+
+        q = (w - self._biases) / self._scales
+        q = mx.floor(q) + sigmoid(self.rounding)
+        q = mx.clip(q, 0, 2**bits - 1)
+        q = q.astype(mx.uint32)
+
+        w = q * self._scales * 2 * sigmoid(self.v) + self._biases
+        w = w.reshape(w.shape[0], -1)
+
+        q = q.reshape(q.shape[0], -1)
+        bitarr = (q[..., mx.newaxis] >> mx.arange(bits, dtype=mx.uint32)) & 1
+        w_q = bitarr.reshape((q.shape[0], -1, 32))
+        w_q = (w_q << mx.arange(32, dtype=mx.uint32)).sum(axis=-1)
+
+        qlayer = nn.QuantizedLinear(input_dims, output_dims, use_bias, group_size, bits)
+        new_scales = self._scales * 2 * sigmoid(self.v)
+        qlayer.weight = w_q
+        qlayer.scales = new_scales[..., 0]
+        qlayer.biases = self._biases[..., 0]
+        if use_bias:
+            qlayer.bias = self._bias
+        return qlayer
+
+
+def round_block(
+    block: nn.Module,
+    inputs: mx.array,
+    outputs: mx.array,
+    group_size: int = 64,
+    bits: int = 3,
+    layer_kwargs: dict | None = None,
+    batch_size: int = 4,
+):
+    layer_kwargs = layer_kwargs or {}
+    block.freeze()
+    leaves = block.leaf_modules()
+    rounded = tree_map(
+        lambda m: RoundQuant(m, group_size, bits) if isinstance(m, nn.Linear) else m,
+        leaves,
+        is_leaf=nn.Module.is_module,
+    )
+    block.update_modules(rounded)
+
+    def hard_round(module, threshold: float = 0):
+        if not isinstance(module, RoundQuant):
+            return module
+        score = mx.abs(sigmoid(module.rounding) - 0.5)
+        value = mx.array(np.quantile(score.astype(mx.float32), q=threshold))
+        rounding = mx.where(
+            sigmoid(module.rounding) > value + 0.5, float("inf"), module.rounding
+        )
+        module.rounding = mx.where(
+            sigmoid(module.rounding) <= 0.5 - value, -float("inf"), rounding
+        )
+        return module
+
+    for threshold in ROUNDING_THRESHOLDS:
+        print("threshold", threshold)
+        optimizer = optim.Adam(learning_rate=1e-3)
+        tree_map(
+            lambda m: hard_round(m, threshold),
+            block.leaf_modules(),
+            is_leaf=nn.Module.is_module,
+        )
+
+        def loss(block, inputs, outputs):
+            outputs_q = run_layer(block, inputs, **layer_kwargs)
+            return mse(outputs, outputs_q).mean()
+
+        loss_value_and_grad = nn.value_and_grad(block, loss)
+        for i in range(0, inputs.shape[0], batch_size):
+            lvalue, grad = loss_value_and_grad(
+                block, inputs[i : i + batch_size], outputs[i : i + batch_size]
+            )
+            if mx.distributed.is_available():
+                grad = average_gradients(grad)
+            optimizer.update(block, grad)
+            mx.eval(block.parameters(), optimizer.state, lvalue)
+        print(lvalue)
+
+    tree_map(hard_round, block.leaf_modules(), is_leaf=nn.Module.is_module)
+
+
 def awq_quantize(
     model,
     inputs: mx.array,
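
Note: RoundQuant swaps each nn.Linear in a block for a soft-quantized stand-in with two learnable pieces: rounding, a per-weight soft rounding direction initialized from the fractional part of weight / scale, and v, a per-group scale correction that enters as a 2 * sigmoid(v) multiplier and equals 1 at initialization. round_block trains both against the block's float outputs, while hard_round walks the decreasing quantile schedule in ROUNDING_THRESHOLDS and pins the most decided rounding variables to +/-inf so they become plain up/down rounds; the final tree_map call (default threshold 0) hardens the rest. to_quantized then folds the learned multiplier into the scales and bit-packs the rounded values into a standard nn.QuantizedLinear. A hedged consistency sketch, not part of the commit (assumes the functions defined in this file are in scope):

# Sanity sketch: once every rounding decision is hard, the packed layer from
# to_quantized should reproduce RoundQuant's own forward pass, up to the
# numerics of the fused quantized matmul.
import mlx.core as mx
import mlx.nn as nn

group_size, bits = 64, 3
lin = nn.Linear(256, 128, bias=False)
rq = RoundQuant(lin, group_size=group_size, bits=bits)

# Pin every rounding variable: > 0.5 rounds up, <= 0.5 rounds down.
rq.rounding = mx.where(sigmoid(rq.rounding) > 0.5, float("inf"), -float("inf"))

x = mx.random.normal((8, 256))
y_soft = rq(x)                                    # soft-quantized forward pass
y_packed = rq.to_quantized(group_size, bits)(x)   # packed nn.QuantizedLinear
print(mx.abs(y_soft - y_packed).max())            # expected to be near zero
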
@@ -325,6 +489,16 @@ def awq_quantize(
             group_size=group_size,
         )

+        print("Rounding block")
+        round_block(
+            block=layer,
+            inputs=inputs,
+            outputs=outputs,
+            group_size=group_size,
+            bits=bits,
+            layer_kwargs={"mask": mask},
+        )
+
         nn.quantize(layer, group_size=group_size, bits=bits)
         outputs_q = run_layer(layer, inputs, mask=mask)
         loss = mse(outputs, outputs_q).sum()
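
Note: taken together, the per-block flow in awq_quantize after this change is roughly: compute the float reference outputs, run the AWQ scale and clip searches (now with much smaller grids), learn the TesseraQ rounding against those outputs, and only then materialize real quantized layers. A paraphrased sketch, not an exact excerpt; the loop structure and surrounding variable names outside this hunk are assumed from context:

# Paraphrased per-block pipeline (illustrative only)
for layer in layers:
    outputs = run_layer(layer, inputs, mask=mask)   # float reference outputs

    # AWQ: per-channel scale search (n_grid now 1) and weight clipping (n_grid now 2);
    # the search_best_scale / clip_block calls are elided here.

    # TesseraQ: learn rounding directions and scale corrections per RoundQuant layer.
    print("Rounding block")
    round_block(
        block=layer,
        inputs=inputs,
        outputs=outputs,
        group_size=group_size,
        bits=bits,
        layer_kwargs={"mask": mask},
    )

    # nn.quantize picks up modules that define to_quantized, so the RoundQuant
    # stand-ins are frozen into packed weights alongside any remaining Linears.
    nn.quantize(layer, group_size=group_size, bits=bits)
    outputs_q = run_layer(layer, inputs, mask=mask)
    loss = mse(outputs, outputs_q).sum()            # per-block reconstruction error
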