From ae53ed9090df32f2f49c6e3e140833e13d850d25 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 19 Dec 2024 19:35:26 -0800 Subject: [PATCH] add TesseraQ rounding --- llms/mlx_lm/awq.py | 613 +++++++++++++++++++++++++++++++++++++++++++++ llms/setup.py | 1 + 2 files changed, 614 insertions(+) create mode 100644 llms/mlx_lm/awq.py diff --git a/llms/mlx_lm/awq.py b/llms/mlx_lm/awq.py new file mode 100644 index 00000000..05264c53 --- /dev/null +++ b/llms/mlx_lm/awq.py @@ -0,0 +1,613 @@ +# Learned quantization using AWQ and TesseraQ: + +# References: +# AWQ: +# https://arxiv.org/abs/2306.00978 +# https://github.com/mit-han-lab/llm-awq +# TesseraQ: +# https://arxiv.org/abs/2410.19103 +# https://github.com/Intelligent-Computing-Lab-Yale/TesseraQ + +import argparse +import glob +import shutil +from pathlib import Path +from typing import Callable + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from datasets import Dataset, load_dataset +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten, tree_map, tree_map_with_path +from mlx_lm.models.base import create_attention_mask +from mlx_lm.tokenizer_utils import TokenizerWrapper +from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights +from tqdm import tqdm + +ROUNDING_THRESHOLDS = [ + 0.8, + 0.65, + 0.5, + 0.43, + 0.38, + 0.34, + 0.3, + 0.27, + 0.24, + 0.21, + 0.18, + 0.15, + 0.12, + 0.10, + 0.08, + 0.06, + 0.04, + 0.02, + 0.01, + 0.005, +] + + +def mse(x, y): + return ((x - y).astype(mx.float32)) ** 2 + + +def sigmoid(x: mx.array, gamma: float = -0.1, zeta: float = 1.1): + return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1) + + +def sigmoid_inverse(y: mx.array, gamma: float = -0.1, zeta: float = 1.1): + return -mx.log((zeta - gamma) / (y - gamma) - 1) + + +def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs): + y = [] + for i in range(0, x.shape[0], batch_size): + y.append(layer(x[i : i + batch_size], **kwargs)) + mx.eval(y) + y = mx.concatenate(y, axis=0) + return y + + +def dist_split(x: mx.array, group: mx.distributed.Group): + B = x.shape[0] + N = group.size() + assert B % N == 0 + r = group.rank() + local_B = (B + N - 1) // N + return x[r * local_B : (r + 1) * local_B] + + +def search_best_scale( + layers: list[nn.Module], + x: mx.array, + quantize_func: Callable, + block: nn.Module | None = None, + layer_kwargs: dict | None = None, + n_grid: int = 1, +): + group = mx.distributed.init() if mx.distributed.is_available() else None + layer_kwargs = layer_kwargs or {} + + block = block or layers[0] + out = block(x, **layer_kwargs) + + x_max = x.abs().mean(axis=(0, 1)) + + best_error = float("inf") + best_scales = None + + weights = tree_flatten(block.parameters()) + + for ratio in tqdm(range(n_grid)): + ratio = ratio * 1 / n_grid + scales = mx.maximum(x_max**ratio, 1e-4).reshape(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + for layer in layers: + layer.weight = quantize_func(layer.weight * scales) / scales + + out_q = run_layer(block, x, **layer_kwargs) + loss = mse(out, out_q).sum() + if group is not None: + loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size() + loss /= out.size + mx.eval(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_scales = scales + + # reload the original weights + block.load_weights(weights) + + best_scales = best_scales.reshape(-1) + mx.eval(best_scales) + return best_scales + + +def apply_scale(prev_op, layers, scales): + # Apply 
the scales to the layers + if isinstance(prev_op, nn.Linear): + assert len(layers) == 1 + prev_op.weight = prev_op.weight / scales[:, mx.newaxis] + if hasattr(prev_op, "bias"): + prev_op.bias = prev_op.bias / scales + layers[0].weight = layers[0].weight * scales[mx.newaxis] + elif isinstance(prev_op, (nn.LayerNorm, nn.RMSNorm)): + prev_op.weight = prev_op.weight / scales + if hasattr(prev_op, "bias"): + prev_op.bias = prev_op.bias / scales + for layer in layers: + layer.weight = layer.weight * scales + else: + raise NotImplementedError(f"Could not apply scale to prev_op: {prev_op}") + + +def scale_block( + block, input_feat, quantize_func: Callable, layer_kwargs: dict | None = None +): + layers = [ + block.self_attn.q_proj, + block.self_attn.k_proj, + block.self_attn.v_proj, + ] + scales = search_best_scale( + layers=layers, + block=block.self_attn, + x=input_feat["q_proj"], + quantize_func=quantize_func, + layer_kwargs=layer_kwargs, + ) + apply_scale(block.input_layernorm, layers, scales) + for name in ["q_proj", "k_proj", "v_proj"]: + input_feat[name] = input_feat[name] / scales + + layers = [ + block.mlp.gate_proj, + block.mlp.up_proj, + ] + scales = search_best_scale( + block=block.mlp, + layers=layers, + x=input_feat["gate_proj"], + quantize_func=quantize_func, + ) + mlp_norm = getattr( + block, "pre_feedforward_layernorm", block.post_attention_layernorm + ) + apply_scale(mlp_norm, layers, scales) + for name in ["gate_proj", "up_proj"]: + input_feat[name] = input_feat[name] / scales + + layers = [block.mlp.down_proj] + scales = search_best_scale( + layers=layers, + x=input_feat["down_proj"], + quantize_func=quantize_func, + ) + apply_scale(block.mlp.up_proj, layers, scales) + input_feat["down_proj"] = input_feat["down_proj"] / scales + + +def search_best_clip( + w: mx.array, + x: mx.array, + quantize_func: Callable, + group_size: int, + n_grid: int = 2, + max_shrink: float = 0.5, + subsample: int = 4, + batch_size: int = 64, +): + group = mx.distributed.init() if mx.distributed.is_available() else None + + x = x[:, ::subsample] + x = x.reshape(*x.shape[:-1], -1, group_size) + + w_all = w + w_max_all = [] + w_min_all = [] + + for b in range(0, w.shape[0], batch_size): + w = w_all[b : b + batch_size] + + group_shape = (w.shape[0], w.shape[-1] // group_size) + best_error = mx.full(group_shape, float("inf")) + best_w_max = mx.zeros((*group_shape, 1), dtype=x.dtype) + best_w_min = mx.zeros((*group_shape, 1), dtype=x.dtype) + + w_shape = w.shape + + w = w.reshape(*w.shape[:-1], -1, group_size) + out = mx.einsum("btdg,odg->btod", x, w) + + for i in range(int(max_shrink * n_grid)): + p = 1 - i / n_grid + w_max = p * w.max(axis=-1, keepdims=True) + w_min = p * w.min(axis=-1, keepdims=True) + w_m = mx.clip(w, w_min, w_max).reshape(w_shape) + + w_q = quantize_func(w_m) + + w_q = w_q.reshape(*w_q.shape[:-1], -1, group_size) + out_q = mx.einsum("btdg,odg->btod", x, w_q) + + # Take the mean across the input batch + loss = mse(out, out_q).sum(axis=(0, 1)) + if group is not None: + loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size() + loss /= out.shape[0] * out.shape[1] + best_indices = loss < best_error + best_error = mx.where(best_indices, loss, best_error) + best_w_max = mx.where(best_indices[..., mx.newaxis], w_max, best_w_max) + best_w_min = mx.where(best_indices[..., mx.newaxis], w_min, best_w_min) + mx.eval(best_w_max, best_w_min, best_error) + + w_max_all.append(best_w_max) + w_min_all.append(best_w_min) + + best_w_max = mx.concatenate(w_max_all, axis=0) + best_w_min = 
mx.concatenate(w_min_all, axis=0) + + w_r = w_all.reshape(*w_all.shape[:-1], -1, group_size) + best_w = mx.clip(w_r, best_w_min, best_w_max) + best_w = best_w.reshape(w_all.shape) + + mx.eval(best_w) + return best_w + + +def clip_block( + block: nn.Module, + input_feat: dict[str, mx.array], + quantize_func: Callable, + group_size: int, +): + + def apply_clip(path, module): + if ( + isinstance(module, nn.Linear) + and "q_proj" not in path + and "k_proj" not in path + ): + name = path.split(".")[-1] + best_weight = search_best_clip( + module.weight, + input_feat[name], + quantize_func=quantize_func, + group_size=group_size, + ) + module.weight = best_weight + + tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module) + + +class RoundQuant(nn.Module): + def __init__(self, module, group_size: int = 64, bits: int = 3): + super().__init__() + self.bits = bits + self.group_size = group_size + self._weight = module.weight + if hasattr(module, "bias"): + self._bias = module.bias + + _, self._scales, self._biases = mx.quantize( + self._weight, group_size=group_size, bits=bits + ) + self._scales = self._scales[..., mx.newaxis] + self._biases = self._biases[..., mx.newaxis] + + self._weight = self._weight.reshape(self._weight.shape[0], -1, group_size) + rounding = self._weight / self._scales + rounding = rounding - mx.floor(rounding) + self.rounding = sigmoid_inverse(rounding) + self.v = mx.zeros_like(self._scales) + + def __call__(self, x: mx.array): + q = (self._weight - self._biases) / self._scales + q = mx.floor(q) + sigmoid(self.rounding) + q = mx.clip(q, 0, 2**self.bits - 1) + w = (q * self._scales * 2 * sigmoid(self.v)) + self._biases + w = w.reshape(w.shape[0], -1) + + if hasattr(self, "_bias"): + x = mx.addmm(self._bias, x, w.T) + else: + x = x @ w.T + return x + + def to_quantized(self, group_size: int = 64, bits: int = 3): + assert ( + group_size == self.group_size and bits == self.bits + ), "Quantization parameters must match" + w = self._weight + output_dims, input_dims = w.shape[0], w.shape[1] * w.shape[2] + use_bias = hasattr(self, "_bias") + + q = (w - self._biases) / self._scales + q = mx.floor(q) + sigmoid(self.rounding) + q = mx.clip(q, 0, 2**bits - 1) + q = q.astype(mx.uint32) + + w = q * self._scales * 2 * sigmoid(self.v) + self._biases + w = w.reshape(w.shape[0], -1) + + q = q.reshape(q.shape[0], -1) + bitarr = (q[..., mx.newaxis] >> mx.arange(bits, dtype=mx.uint32)) & 1 + w_q = bitarr.reshape((q.shape[0], -1, 32)) + w_q = (w_q << mx.arange(32, dtype=mx.uint32)).sum(axis=-1) + + qlayer = nn.QuantizedLinear(input_dims, output_dims, use_bias, group_size, bits) + new_scales = self._scales * 2 * sigmoid(self.v) + qlayer.weight = w_q + qlayer.scales = new_scales[..., 0] + qlayer.biases = self._biases[..., 0] + if use_bias: + qlayer.bias = self._bias + + return qlayer + + +def round_block( + block: nn.Module, + inputs: mx.array, + outputs: mx.array, + group_size: int = 64, + bits: int = 3, + layer_kwargs: dict | None = None, + batch_size: int = 4, +): + layer_kwargs = layer_kwargs or {} + + block.freeze() + leaves = block.leaf_modules() + rounded = tree_map( + lambda m: RoundQuant(m, group_size, bits) if isinstance(m, nn.Linear) else m, + leaves, + is_leaf=nn.Module.is_module, + ) + block.update_modules(rounded) + + def hard_round(module, threshold: float = 0): + if not isinstance(module, RoundQuant): + return module + score = mx.abs(sigmoid(module.rounding) - 0.5) + value = mx.array(np.quantile(score.astype(mx.float32), q=threshold)) + rounding = mx.where( + 
sigmoid(module.rounding) > value + 0.5, float("inf"), module.rounding + ) + module.rounding = mx.where( + sigmoid(module.rounding) <= 0.5 - value, -float("inf"), rounding + ) + return module + + for threshold in ROUNDING_THRESHOLDS: + print("threshold", threshold) + optimizer = optim.Adam(learning_rate=1e-3) + + tree_map( + lambda m: hard_round(m, threshold), + block.leaf_modules(), + is_leaf=nn.Module.is_module, + ) + + def loss(block, inputs, outputs): + outputs_q = run_layer(block, inputs, **layer_kwargs) + return mse(outputs, outputs_q).mean() + + loss_value_and_grad = nn.value_and_grad(block, loss) + + for i in range(0, inputs.shape[0], batch_size): + lvalue, grad = loss_value_and_grad( + block, inputs[i : i + batch_size], outputs[i : i + batch_size] + ) + if mx.distributed.is_available(): + grad = average_gradients(grad) + optimizer.update(block, grad) + mx.eval(block.parameters(), optimizer.state, lvalue) + print(lvalue) + + tree_map(hard_round, block.leaf_modules(), is_leaf=nn.Module.is_module) + + +def awq_quantize( + model, + inputs: mx.array, + group_size: int = 64, + bits: int = 3, + embed_group_size: int = 32, + embed_bits: int = 4, +): + group = mx.distributed.init() if mx.distributed.is_available() else None + + def quantize_func(w): + wq = mx.quantize(w, bits=bits, group_size=group_size) + return mx.dequantize(*wq, bits=bits, group_size=group_size) + + mask = create_attention_mask(inputs) + + model.model.embed_tokens = model.model.embed_tokens.to_quantized( + group_size=embed_group_size, bits=embed_bits + ) + inputs = model.model.embed_tokens(inputs) + + input_feat = {} + + def capture(path, module): + if not isinstance(module, nn.Linear): + return module + + class Catcher(nn.Module): + def __call__(self, x: mx.array): + name = path.split(".")[-1] + input_feat[name] = x + return module(x) + + return Catcher() + + for i, layer in enumerate(model.model.layers): + import time + + s = time.time() + print(f"Starting block {i}") + + # capture the inputs to each layer + orig_leaves = layer.leaf_modules() + capture_leaves = tree_map_with_path( + capture, orig_leaves, is_leaf=nn.Module.is_module + ) + layer.update_modules(capture_leaves) + outputs = run_layer(layer, inputs, mask=mask) + layer.update_modules(orig_leaves) + del capture_leaves + + nn.quantize(layer, group_size=group_size, bits=bits) + outputs_q = run_layer(layer, inputs, mask=mask) + loss = mse(outputs, outputs_q).sum() + if group is not None: + loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size() + loss /= outputs.size + print("Before Loss", loss, flush=True) + layer.update_modules(orig_leaves) + del orig_leaves + + print("Scaling block", flush=True) + scale_block( + block=layer, + input_feat=input_feat, + quantize_func=quantize_func, + layer_kwargs={"mask": mask}, + ) + + print("Clipping block", flush=True) + clip_block( + block=layer, + input_feat=input_feat, + quantize_func=quantize_func, + group_size=group_size, + ) + + print("Rounding block") + round_block( + block=layer, + inputs=inputs, + outputs=outputs, + group_size=group_size, + bits=bits, + layer_kwargs={"mask": mask}, + ) + + nn.quantize(layer, group_size=group_size, bits=bits) + outputs_q = run_layer(layer, inputs, mask=mask) + loss = mse(outputs, outputs_q).sum() + if group is not None: + loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size() + loss /= outputs.size + print("After Loss", loss, flush=True) + + input_feat = {} + inputs = outputs + + mx.eval(layer) + mx.metal.clear_cache() + + e = time.time() + print("Loop time: ", e 
- s) + + if hasattr(model, "lm_head"): + model.lm_head = model.lm_head.to_quantized( + group_size=embed_group_size, bits=embed_bits + ) + + +def load_wikitext( + tokenizer, num_samples: int = 32, sequence_length: int = 2048, split: str = "train" +) -> mx.array: + dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split) + texts = "\n\n".join(dataset["text"]) + tokens = tokenizer.encode(texts, return_tensors="mlx")[0] + + # Select random chunks + starts = mx.random.randint( + 0, len(tokens) - sequence_length - 1, shape=(num_samples, 1) + ) + data = tokens[starts + mx.arange(sequence_length)] + if tokenizer.bos_token_id: + data = mx.concatenate( + [mx.full((*data.shape[:2], 1), tokenizer.bos_token_id), data], axis=-1 + ) + return data + + +def save_model( + model: nn.Module, + tokenizer: TokenizerWrapper, + config, + model_path: Path, + mlx_path: str, +): + weights = dict(tree_flatten(model.parameters())) + + mlx_path = Path(mlx_path) + save_weights(mlx_path, weights, donate_weights=True) + + py_files = glob.glob(str(model_path / "*.py")) + for file in py_files: + shutil.copy(file, mlx_path) + + tokenizer.save_pretrained(mlx_path) + + config["quantization"] = {"group_size": 64, "bits": 4} + + def update_config(path, module): + if hasattr(module, "bits"): + config["quantization"][path] = { + "group_size": module.group_size, + "bits": module.bits, + } + else: + config["quantization"][path] = False + + tree_map_with_path(update_config, model.leaf_modules(), is_leaf=nn.Module.is_module) + + save_config(config, config_path=mlx_path / "config.json") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", "-m", default="mlx-community/Qwen2.5-7B-Instruct-bf16" + ) + parser.add_argument("--mlx-path", default="mlx_model") + parser.add_argument("--bits", type=int, default=3) + parser.add_argument("--group-size", type=int, default=64) + parser.add_argument("--num-samples", type=int, default=32) + parser.add_argument("--sequence-length", type=int, default=2048) + parser.add_argument("--seed", type=int, default=123) + args = parser.parse_args() + + group = mx.distributed.init() if mx.distributed.is_available() else None + + num_samples = args.num_samples + if group is not None and num_samples % group.size() > 0: + num_samples += group.size() - num_samples % group.size() + + mx.random.seed(args.seed) + + model_path = get_model_path(args.model, revision=None) + model, config, tokenizer = fetch_from_hub(model_path, lazy=True) + + calibration_data = load_wikitext(tokenizer, args.num_samples, args.sequence_length) + + if group is not None: + calibration_data = dist_split(calibration_data, group) + + awq_quantize(model, calibration_data, bits=args.bits, group_size=args.group_size) + + save_model(model, tokenizer, config, model_path, args.mlx_path) + + +if __name__ == "__main__": + main() diff --git a/llms/setup.py b/llms/setup.py index b88dcd33..2a28eae1 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -32,6 +32,7 @@ setup( }, entry_points={ "console_scripts": [ + "mlx_lm.awq = mlx_lm.awq:main", "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main",
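
Usage note (not part of the patch): with the new console script installed, a run using the argparse defaults above would look roughly like

    mlx_lm.awq --model mlx-community/Qwen2.5-7B-Instruct-bf16 --bits 3 --group-size 64 \
        --num-samples 32 --sequence-length 2048 --mlx-path mlx_model

All flags mirror the defaults in main(); the quantized weights, tokenizer files, and updated config.json are written to --mlx-path.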
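For reference, a minimal standalone sketch (assuming only mlx is installed; names here are illustrative, not from the patch) of the quantize-then-dequantize round trip that awq_quantize builds as `quantize_func` and passes to the scale and clip searches, so the search measures the error the final quantized weights would actually incur:

    import mlx.core as mx

    def fake_quantize(w: mx.array, group_size: int = 64, bits: int = 3) -> mx.array:
        # mx.quantize returns (packed weights, scales, biases); dequantizing them
        # reproduces the float weights a quantized layer would actually use.
        w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
        return mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)

    w = mx.random.normal((128, 256))
    print(((w - fake_quantize(w)) ** 2).mean())  # mean per-weight quantization error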
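And a toy illustration (standalone, with hypothetical values and a fixed cutoff instead of the quantile used by hard_round) of the TesseraQ rounding reparameterization behind RoundQuant/round_block: each weight's fractional part is mapped through the stretched, clipped sigmoid so rounding can be optimized continuously, and entries whose soft round is already far from 0.5 are pinned to exactly 0 or 1 by sending their logits to plus or minus infinity:

    import mlx.core as mx
    import mlx.nn as nn

    def clipped_sigmoid(x, gamma=-0.1, zeta=1.1):
        # same stretch-and-clip as sigmoid() in the patch
        return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)

    # soft rounding logits for a few fractional parts (sigmoid_inverse in the patch)
    frac = mx.array([0.1, 0.45, 0.55, 0.9])
    logits = -mx.log((1.1 - (-0.1)) / (frac - (-0.1)) - 1)

    # pin entries whose soft round is far from 0.5; the patch picks this cutoff
    # per layer as a quantile of the scores, 0.3 here is just illustrative
    score = mx.abs(clipped_sigmoid(logits) - 0.5)
    pinned = mx.where(
        score > 0.3,
        mx.where(frac > 0.5, float("inf"), -float("inf")),
        logits,
    )
    print(clipped_sigmoid(pinned))  # confident entries become exactly 0 or 1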