Gather qmm batched kernel and refactoring of quantized (#2078)

This commit is contained in:
Angelos Katharopoulos
2025-04-17 13:53:11 -07:00
committed by GitHub
parent 99eefd2ec0
commit 5de6d94a90
15 changed files with 1479 additions and 449 deletions

View File

@@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[128, 256], # M
[32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
if M < group_size:
continue
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
@@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
)
for kwargs in inputs:
test_shape(1, 32, 128, **kwargs)
test_shape(32, 32, 256, **kwargs)
test_shape(1, 32, 256, **kwargs)
test_shape(32, 256, 32, transpose=False, **kwargs)
@@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
parameters = [
# L, K, D, E, I, transpose
(128, 1024, 1024, 32, 4, True),
(128, 1024, 544, 32, 4, True),
(433, 1024, 1024, 32, 4, True),
(433, 1024, 555, 32, 4, True),
(433, 2048, 1024, 32, 4, True),
(128, 1024, 1024, 32, 4, False),
(128, 1024, 544, 32, 4, False),
(433, 1024, 1024, 32, 4, False),
(433, 1024, 544, 32, 4, False),
(433, 1024, 555, 32, 4, False),
(433, 2048, 1024, 32, 4, False),
]
for L, K, D, E, I, transpose in parameters:
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
if __name__ == "__main__":
unittest.main()