Add rudimentary test for gather_mm with sorted indices

This commit is contained in:
Angelos Katharopoulos
2025-07-03 14:02:33 -07:00
parent 4babc035a3
commit d96a33c776

View File

@@ -1164,7 +1164,35 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_gather_mm_sorted(self):
pass
def gather_mm_ref(a, b, rhs):
b = b[rhs]
return a @ b
def gather_mm_test(a, b, rhs):
return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
a = mx.random.normal((100, 1, 100))
b = mx.random.normal((8, 100, 100))
rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
c1 = gather_mm_ref(a, b, rhs)
c2 = gather_mm_test(a, b, rhs)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
cotan = mx.random.normal(c1.shape)
c1, dc1 = mx.vjp(
lambda a, b: gather_mm_ref(a, b, rhs),
[a, b],
[cotan],
)
c2, dc2 = mx.vjp(
lambda a, b: gather_mm_test(a, b, rhs),
[a, b],
[cotan],
)
self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
def test_segmented_mm(self):
def segmented_mm_ref(a, b, s):