Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
| | fc81342afe | |
| | 77d75f3ccc | |
@@ -1,12 +1,9 @@
-# Learned quantization using AWQ and TesseraQ:
+# Learned quantization using AWQ:
 # References:
-# AWQ:
+# AWQ
 # https://arxiv.org/abs/2306.00978
 # https://github.com/mit-han-lab/llm-awq
-# TesseraQ:
-# https://arxiv.org/abs/2410.19103
-# https://github.com/Intelligent-Computing-Lab-Yale/TesseraQ
 
 import argparse
 import glob
@@ -16,52 +13,19 @@ from typing import Callable
 
 import mlx.core as mx
 import mlx.nn as nn
-import mlx.optimizers as optim
 import numpy as np
 from datasets import Dataset, load_dataset
-from mlx.nn.utils import average_gradients
-from mlx.utils import tree_flatten, tree_map, tree_map_with_path
+from mlx.utils import tree_flatten, tree_map_with_path
 from mlx_lm.models.base import create_attention_mask
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights
 from tqdm import tqdm
 
-ROUNDING_THRESHOLDS = [
-    0.8,
-    0.65,
-    0.5,
-    0.43,
-    0.38,
-    0.34,
-    0.3,
-    0.27,
-    0.24,
-    0.21,
-    0.18,
-    0.15,
-    0.12,
-    0.10,
-    0.08,
-    0.06,
-    0.04,
-    0.02,
-    0.01,
-    0.005,
-]
-
 
 def mse(x, y):
     return ((x - y).astype(mx.float32)) ** 2
 
 
-def sigmoid(x: mx.array, gamma: float = -0.1, zeta: float = 1.1):
-    return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)
-
-
-def sigmoid_inverse(y: mx.array, gamma: float = -0.1, zeta: float = 1.1):
-    return -mx.log((zeta - gamma) / (y - gamma) - 1)
-
-
 def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs):
     y = []
     for i in range(0, x.shape[0], batch_size):
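For context on what this hunk deletes: `sigmoid`/`sigmoid_inverse` implement the stretched (rectified) sigmoid used for AdaRound/TesseraQ-style soft rounding. Below is a minimal sketch, assuming only `mlx`; the names `soft_round`/`soft_round_inverse` are illustrative stand-ins for the deleted helpers, not part of the repository.

```python
import mlx.core as mx
import mlx.nn as nn


def soft_round(x, gamma=-0.1, zeta=1.1):
    # Rectified ("stretched") sigmoid: a free logit mapped into [0, 1] with flat tails.
    return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)


def soft_round_inverse(y, gamma=-0.1, zeta=1.1):
    # Inverse of the un-clipped map: recovers the logit whose soft value is y in (0, 1).
    return -mx.log((zeta - gamma) / (y - gamma) - 1)


# The deleted RoundQuant module initializes its logits from the fractional part of
# weight / scale, so the soft rounding starts exactly at the continuous remainder.
frac = mx.array([0.1, 0.4, 0.5, 0.9])
logits = soft_round_inverse(frac)
print(soft_round(logits))  # ~[0.1, 0.4, 0.5, 0.9]
```

Because the inverse reproduces the fractional remainder on (0, 1), the soft-quantized weight initially reconstructs the original weight, and training only has to decide which direction each entry should round.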
@@ -86,7 +50,7 @@ def search_best_scale(
     quantize_func: Callable,
     block: nn.Module | None = None,
     layer_kwargs: dict | None = None,
-    n_grid: int = 1,
+    n_grid: int = 20,
 ):
     group = mx.distributed.init() if mx.distributed.is_available() else None
     layer_kwargs = layer_kwargs or {}
@@ -196,7 +160,7 @@ def search_best_clip(
     x: mx.array,
     quantize_func: Callable,
     group_size: int,
-    n_grid: int = 2,
+    n_grid: int = 20,
     max_shrink: float = 0.5,
     subsample: int = 4,
     batch_size: int = 64,
@@ -284,134 +248,6 @@ def clip_block(
     tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module)
 
 
-class RoundQuant(nn.Module):
-    def __init__(self, module, group_size: int = 64, bits: int = 3):
-        super().__init__()
-        self.bits = bits
-        self.group_size = group_size
-        self._weight = module.weight
-        if hasattr(module, "bias"):
-            self._bias = module.bias
-
-        _, self._scales, self._biases = mx.quantize(
-            self._weight, group_size=group_size, bits=bits
-        )
-        self._scales = self._scales[..., mx.newaxis]
-        self._biases = self._biases[..., mx.newaxis]
-
-        self._weight = self._weight.reshape(self._weight.shape[0], -1, group_size)
-        rounding = self._weight / self._scales
-        rounding = rounding - mx.floor(rounding)
-        self.rounding = sigmoid_inverse(rounding)
-        self.v = mx.zeros_like(self._scales)
-
-    def __call__(self, x: mx.array):
-        q = (self._weight - self._biases) / self._scales
-        q = mx.floor(q) + sigmoid(self.rounding)
-        q = mx.clip(q, 0, 2**self.bits - 1)
-        w = (q * self._scales * 2 * sigmoid(self.v)) + self._biases
-        w = w.reshape(w.shape[0], -1)
-
-        if hasattr(self, "_bias"):
-            x = mx.addmm(self._bias, x, w.T)
-        else:
-            x = x @ w.T
-        return x
-
-    def to_quantized(self, group_size: int = 64, bits: int = 3):
-        assert (
-            group_size == self.group_size and bits == self.bits
-        ), "Quantization parameters must match"
-        w = self._weight
-        output_dims, input_dims = w.shape[0], w.shape[1] * w.shape[2]
-        use_bias = hasattr(self, "_bias")
-
-        q = (w - self._biases) / self._scales
-        q = mx.floor(q) + sigmoid(self.rounding)
-        q = mx.clip(q, 0, 2**bits - 1)
-        q = q.astype(mx.uint32)
-
-        w = q * self._scales * 2 * sigmoid(self.v) + self._biases
-        w = w.reshape(w.shape[0], -1)
-
-        q = q.reshape(q.shape[0], -1)
-        bitarr = (q[..., mx.newaxis] >> mx.arange(bits, dtype=mx.uint32)) & 1
-        w_q = bitarr.reshape((q.shape[0], -1, 32))
-        w_q = (w_q << mx.arange(32, dtype=mx.uint32)).sum(axis=-1)
-
-        qlayer = nn.QuantizedLinear(input_dims, output_dims, use_bias, group_size, bits)
-        new_scales = self._scales * 2 * sigmoid(self.v)
-        qlayer.weight = w_q
-        qlayer.scales = new_scales[..., 0]
-        qlayer.biases = self._biases[..., 0]
-        if use_bias:
-            qlayer.bias = self._bias
-
-        return qlayer
-
-
-def round_block(
-    block: nn.Module,
-    inputs: mx.array,
-    outputs: mx.array,
-    group_size: int = 64,
-    bits: int = 3,
-    layer_kwargs: dict | None = None,
-    batch_size: int = 4,
-):
-    layer_kwargs = layer_kwargs or {}
-
-    block.freeze()
-    leaves = block.leaf_modules()
-    rounded = tree_map(
-        lambda m: RoundQuant(m, group_size, bits) if isinstance(m, nn.Linear) else m,
-        leaves,
-        is_leaf=nn.Module.is_module,
-    )
-    block.update_modules(rounded)
-
-    def hard_round(module, threshold: float = 0):
-        if not isinstance(module, RoundQuant):
-            return module
-        score = mx.abs(sigmoid(module.rounding) - 0.5)
-        value = mx.array(np.quantile(score.astype(mx.float32), q=threshold))
-        rounding = mx.where(
-            sigmoid(module.rounding) > value + 0.5, float("inf"), module.rounding
-        )
-        module.rounding = mx.where(
-            sigmoid(module.rounding) <= 0.5 - value, -float("inf"), rounding
-        )
-        return module
-
-    for threshold in ROUNDING_THRESHOLDS:
-        print("threshold", threshold)
-        optimizer = optim.Adam(learning_rate=1e-3)
-
-        tree_map(
-            lambda m: hard_round(m, threshold),
-            block.leaf_modules(),
-            is_leaf=nn.Module.is_module,
-        )
-
-        def loss(block, inputs, outputs):
-            outputs_q = run_layer(block, inputs, **layer_kwargs)
-            return mse(outputs, outputs_q).mean()
-
-        loss_value_and_grad = nn.value_and_grad(block, loss)
-
-        for i in range(0, inputs.shape[0], batch_size):
-            lvalue, grad = loss_value_and_grad(
-                block, inputs[i : i + batch_size], outputs[i : i + batch_size]
-            )
-            if mx.distributed.is_available():
-                grad = average_gradients(grad)
-            optimizer.update(block, grad)
-            mx.eval(block.parameters(), optimizer.state, lvalue)
-            print(lvalue)
-
-    tree_map(hard_round, block.leaf_modules(), is_leaf=nn.Module.is_module)
-
-
 def awq_quantize(
     model,
     inputs: mx.array,
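The deleted `round_block` implements TesseraQ-style progressive adaptive rounding: for each value in `ROUNDING_THRESHOLDS` it pins the most confident rounding decisions to a hard 0/1 and then fine-tunes the remaining soft decisions (and the per-group scale parameter `v`) against the block's float outputs with Adam. Below is a minimal sketch of just the pinning step, assuming `mlx` and `numpy`; `pin_confident` and `soft_round` are hypothetical standalone versions of the deleted `hard_round` and `sigmoid`, not repository APIs.

```python
import mlx.core as mx
import mlx.nn as nn
import numpy as np


def soft_round(x, gamma=-0.1, zeta=1.1):
    # Stretched sigmoid squashed back into [0, 1] (same form as the deleted helper).
    return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)


def pin_confident(rounding, threshold):
    # Confidence of each soft rounding decision = distance from the undecided value 0.5.
    score = mx.abs(soft_round(rounding) - 0.5)
    # Cutoff below which the least-confident `threshold` fraction of entries lies.
    cutoff = mx.array(np.quantile(np.array(score, dtype=np.float32), q=threshold))
    # Pin everything above the cutoff to +/- infinity so it decodes to exactly 1 or 0.
    rounding = mx.where(soft_round(rounding) > cutoff + 0.5, float("inf"), rounding)
    return mx.where(soft_round(rounding) <= 0.5 - cutoff, -float("inf"), rounding)


logits = mx.random.normal((8,))
for t in (0.8, 0.5, 0.2, 0.01):  # shrinking schedule, as in the deleted ROUNDING_THRESHOLDS
    logits = pin_confident(logits, t)
    # In round_block, an Adam pass over the block's output MSE runs between pins.
print(soft_round(logits))  # almost every entry now decodes to exactly 0.0 or 1.0
```

Each pass leaves only the least confident `threshold` fraction of entries trainable, so by the end of the schedule essentially every rounding decision has been frozen.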
@@ -489,16 +325,6 @@ def awq_quantize(
             group_size=group_size,
         )
 
-        print("Rounding block")
-        round_block(
-            block=layer,
-            inputs=inputs,
-            outputs=outputs,
-            group_size=group_size,
-            bits=bits,
-            layer_kwargs={"mask": mask},
-        )
-
         nn.quantize(layer, group_size=group_size, bits=bits)
         outputs_q = run_layer(layer, inputs, mask=mask)
         loss = mse(outputs, outputs_q).sum()
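With the learned-rounding pass removed, the per-block check that remains is plain round-to-nearest quantization via `nn.quantize` followed by an output MSE, as in the context lines above. A toy, self-contained sketch of that measurement (the real script applies it to each transformer block with an attention mask, not to a small MLP):

```python
import mlx.core as mx
import mlx.nn as nn

# A small stand-in "block"; the script does the same per transformer layer.
mlp = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
x = mx.random.normal((16, 64))
y_ref = mlp(x)

# In-place round-to-nearest quantization of the Linear leaves (mlx built-in).
nn.quantize(mlp, group_size=64, bits=3)
y_q = mlp(x)

# Same error measure the script reports per block after quantization.
loss = ((y_ref - y_q).astype(mx.float32) ** 2).sum()
print(loss)
```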