[WIP] Add init QMMs (some CI tests failing)

This commit is contained in:
Jagrit Digani
2025-11-18 12:12:59 -08:00
parent a0558cd3af
commit c532eb94c1
11 changed files with 2035 additions and 87 deletions

View File

@@ -834,47 +834,54 @@ class TestQuantized(mlx_tests.MLXTestCase):
(64, 512, 512, 4, 2, False, "affine"),
]
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(
w, group_size=group_size, mode=mode, transpose=transpose
)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
y4 = mx.gather_qmm(
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
self.assertLess((y1 - y2).abs().max(), 1e-5)
self.assertLess((y1 - y3).abs().max(), 1e-5)
self.assertLess((y1 - y4).abs().max(), 2e-4)
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=2e-4))
def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):