Add a test for segmented_mm

This commit is contained in:
Angelos Katharopoulos
2025-07-03 13:49:46 -07:00
parent a8d7b74984
commit 4babc035a3

View File

@@ -1163,6 +1163,71 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_gather_mm_sorted(self):
pass
def test_segmented_mm(self):
def segmented_mm_ref(a, b, s):
s = s.tolist()
c = []
for s1, s2 in s:
c.append(a[:, s1:s2] @ b[s1:s2, :])
return mx.stack(c, axis=0)
shapes = [
(10, 10, 10),
(10, 10, 1000),
(1000, 1000, 1000),
]
segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
for M, N, K in shapes:
for s in segments:
segments = []
for i in range(len(s) - 1):
segments.append([s[i], s[i + 1]])
segments = mx.array(segments)
segments = mx.maximum(K - 1, segments.astype(mx.uint32))
a = mx.random.normal((M, K))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a, b, segments)
c2 = mx.segmented_mm(a, b, segments)
self.assertTrue(mx.allclose(c1, c2))
a = mx.random.normal((K, M))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a.T, b, segments)
c2 = mx.segmented_mm(a.T, b, segments)
self.assertTrue(mx.allclose(c1, c2))
a = mx.random.normal((M, K))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a, b.T, segments)
c2 = mx.segmented_mm(a, b.T, segments)
self.assertTrue(mx.allclose(c1, c2))
a = mx.random.normal((K, M))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a.T, b.T, segments)
c2 = mx.segmented_mm(a.T, b.T, segments)
self.assertTrue(mx.allclose(c1, c2))
with self.assertRaises(ValueError):
a = mx.ones((2, 10, 10))
s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
mx.segmented_mm(a, a, s)
a = mx.ones((10, 1000))
s = mx.random.randint(0, 16, shape=(1000,))
s = mx.zeros(16).at[s].add(1)
s = mx.sort(s)
s = mx.cumsum(s)
s = mx.concatenate([mx.array([0]), s])
s = mx.as_strided(s, (16, 2), (1, 1))
s = mx.reshape(s, (2, 2, 4, 2))
c = mx.segmented_mm(a, a.T, s)
self.assertEqual(c.shape, (2, 2, 4, 10, 10))
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256