From d96a33c7767f64ea6dea1f33852b49a99f017c6e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 3 Jul 2025 14:02:33 -0700 Subject: [PATCH] Add rudimentary test for gather_mm with sorted indices --- python/tests/test_blas.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 3ab01c4ef..2490f3ab2 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1164,7 +1164,35 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) def test_gather_mm_sorted(self): - pass + def gather_mm_ref(a, b, rhs): + b = b[rhs] + return a @ b + + def gather_mm_test(a, b, rhs): + return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True) + + a = mx.random.normal((100, 1, 100)) + b = mx.random.normal((8, 100, 100)) + rhs = mx.sort(mx.random.randint(0, 8, shape=(100,))) + + c1 = gather_mm_ref(a, b, rhs) + c2 = gather_mm_test(a, b, rhs) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + cotan = mx.random.normal(c1.shape) + c1, dc1 = mx.vjp( + lambda a, b: gather_mm_ref(a, b, rhs), + [a, b], + [cotan], + ) + c2, dc2 = mx.vjp( + lambda a, b: gather_mm_test(a, b, rhs), + [a, b], + [cotan], + ) + self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4)) def test_segmented_mm(self): def segmented_mm_ref(a, b, s):