mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 22:04:45 +08:00
Gather qmm batched kernel and refactoring of quantized (#2078)
This commit is contained in:

committed by
GitHub

parent
99eefd2ec0
commit
5de6d94a90
@@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 3, 4, 6, 8], # bits
|
||||
[128, 256], # M
|
||||
[32, 128, 256], # M
|
||||
[128, 256, 67], # N
|
||||
[0, 1, 3, 8], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
||||
if M < group_size:
|
||||
continue
|
||||
x_shape = (1, N) if B == 0 else (B, 1, N)
|
||||
w_shape = (N, M) if B == 0 else (B, N, M)
|
||||
x = mx.random.normal(shape=x_shape, key=k1)
|
||||
@@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
for kwargs in inputs:
|
||||
test_shape(1, 32, 128, **kwargs)
|
||||
test_shape(32, 32, 256, **kwargs)
|
||||
test_shape(1, 32, 256, **kwargs)
|
||||
test_shape(32, 256, 32, transpose=False, **kwargs)
|
||||
@@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
|
||||
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
|
||||
|
||||
def test_gather_qmm_sorted(self):
|
||||
def quantize(w, transpose=True, group_size=64, bits=4):
|
||||
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
|
||||
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
|
||||
if transpose:
|
||||
w_hat = w_hat.swapaxes(-1, -2)
|
||||
return w_hat, qw, s, b
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
parameters = [
|
||||
# L, K, D, E, I, transpose
|
||||
(128, 1024, 1024, 32, 4, True),
|
||||
(128, 1024, 544, 32, 4, True),
|
||||
(433, 1024, 1024, 32, 4, True),
|
||||
(433, 1024, 555, 32, 4, True),
|
||||
(433, 2048, 1024, 32, 4, True),
|
||||
(128, 1024, 1024, 32, 4, False),
|
||||
(128, 1024, 544, 32, 4, False),
|
||||
(433, 1024, 1024, 32, 4, False),
|
||||
(433, 1024, 544, 32, 4, False),
|
||||
(433, 1024, 555, 32, 4, False),
|
||||
(433, 2048, 1024, 32, 4, False),
|
||||
]
|
||||
for L, K, D, E, I, transpose in parameters:
|
||||
K, D = (K, D) if transpose else (D, K)
|
||||
ishape = (L, I)
|
||||
xshape = (L, 1, 1, K)
|
||||
wshape = (E, D, K) if transpose else (E, K, D)
|
||||
|
||||
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
|
||||
x = mx.random.normal(xshape) / K**0.5
|
||||
w = mx.random.normal(wshape) / K**0.5
|
||||
w, *wq = quantize(w, transpose=transpose)
|
||||
|
||||
y1 = mx.gather_mm(x, w, rhs_indices=indices)
|
||||
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
|
||||
xs, idx, inv_order = gather_sort(x, indices)
|
||||
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
|
||||
y4 = mx.gather_qmm(
|
||||
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
|
||||
)
|
||||
y3 = scatter_unsort(y3, inv_order, indices.shape)
|
||||
y4 = scatter_unsort(y4, inv_order, indices.shape)
|
||||
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user