From 4babc035a374f7887355c5107c30857964fe8829 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 3 Jul 2025 13:49:46 -0700 Subject: [PATCH] Add a test for segmented_mm --- python/tests/test_blas.py | 65 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index eb45df124..3ab01c4ef 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1163,6 +1163,71 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_gather_mm_sorted(self): + pass + + def test_segmented_mm(self): + def segmented_mm_ref(a, b, s): + s = s.tolist() + c = [] + for s1, s2 in s: + c.append(a[:, s1:s2] @ b[s1:s2, :]) + return mx.stack(c, axis=0) + + shapes = [ + (10, 10, 10), + (10, 10, 1000), + (1000, 1000, 1000), + ] + segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] + + for M, N, K in shapes: + for s in segments: + segments = [] + for i in range(len(s) - 1): + segments.append([s[i], s[i + 1]]) + segments = mx.array(segments) + segments = mx.maximum(K - 1, segments.astype(mx.uint32)) + a = mx.random.normal((M, K)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a, b, segments) + c2 = mx.segmented_mm(a, b, segments) + self.assertTrue(mx.allclose(c1, c2)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a.T, b, segments) + c2 = mx.segmented_mm(a.T, b, segments) + self.assertTrue(mx.allclose(c1, c2)) + + a = mx.random.normal((M, K)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a, b.T, segments) + c2 = mx.segmented_mm(a, b.T, segments) + self.assertTrue(mx.allclose(c1, c2)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a.T, b.T, segments) + c2 = mx.segmented_mm(a.T, b.T, segments) + self.assertTrue(mx.allclose(c1, c2)) + + with self.assertRaises(ValueError): + a = mx.ones((2, 10, 10)) + s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32) + mx.segmented_mm(a, a, s) + + a = mx.ones((10, 1000)) + s = mx.random.randint(0, 16, shape=(1000,)) + s = mx.zeros(16).at[s].add(1) + s = mx.sort(s) + s = mx.cumsum(s) + s = mx.concatenate([mx.array([0]), s]) + s = mx.as_strided(s, (16, 2), (1, 1)) + s = mx.reshape(s, (2, 2, 4, 2)) + c = mx.segmented_mm(a, a.T, s) + self.assertEqual(c.shape, (2, 2, 4, 10, 10)) + def test_gemv_gemm_same_precision(self): mx.random.seed(0) N = 256