Fix the test and cpu edge case

This commit is contained in:
Angelos Katharopoulos
2025-07-04 18:36:20 -07:00
parent bd0622c4d9
commit 2d0f452aae
2 changed files with 8 additions and 7 deletions

View File

@@ -81,6 +81,7 @@ inline void segmented_mm(
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;

View File

@@ -1207,38 +1207,38 @@ class TestBlas(mlx_tests.MLXTestCase):
(10, 10, 1000),
(1000, 1000, 1000),
]
segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
for M, N, K in shapes:
for s in segments:
for s in all_segments:
segments = []
for i in range(len(s) - 1):
segments.append([s[i], s[i + 1]])
segments = mx.array(segments)
segments = mx.maximum(K - 1, segments.astype(mx.uint32))
segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
a = mx.random.normal((M, K))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a, b, segments)
c2 = mx.segmented_mm(a, b, segments)
self.assertTrue(mx.allclose(c1, c2))
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a.T, b, segments)
c2 = mx.segmented_mm(a.T, b, segments)
self.assertTrue(mx.allclose(c1, c2))
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((M, K))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a, b.T, segments)
c2 = mx.segmented_mm(a, b.T, segments)
self.assertTrue(mx.allclose(c1, c2))
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a.T, b.T, segments)
c2 = mx.segmented_mm(a.T, b.T, segments)
self.assertTrue(mx.allclose(c1, c2))
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
with self.assertRaises(ValueError):
a = mx.ones((2, 10, 10))