Add more tests and fix qmm gradient

This commit is contained in:
Angelos Katharopoulos
2025-07-05 02:41:39 -07:00
parent 3d4174cd37
commit 9e5bb5295a
2 changed files with 56 additions and 8 deletions

View File

@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
if lhs is not None:
x = x[lhs]
if rhs is not None:
w = w[rhs]
s = s[rhs]
b = b[rhs]
return mx.quantized_matmul(x, w, s, b, transpose=trans)
def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
return mx.gather_qmm(
x,
w,
s,
b,
transpose=trans,
lhs_indices=lhs,
rhs_indices=rhs,
sorted_indices=sort,
)
x = mx.random.normal((128, 1, 1024))
w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
cotan = mx.random.normal((128, 1, 1024))
(o1,), (dx1, ds1, db1) = mx.vjp(
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
[x, s, b],
[cotan],
)
(o2,), (dx2, ds2, db2) = mx.vjp(
lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
[x, s, b],
[cotan],
)
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
def test_vjp_scales_biases(self):
mx.random.seed(0)
x = mx.random.normal(shape=(2, 2, 512))