Fix the test and cpu edge case

This commit is contained in:
Angelos Katharopoulos
2025-07-04 18:36:20 -07:00
parent bd0622c4d9
commit 2d0f452aae
2 changed files with 8 additions and 7 deletions

View File

@@ -81,6 +81,7 @@ inline void segmented_mm(
uint32_t k_end = uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) { if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue; continue;
} }
a_copy[ndim - 1] = k_end - k_start; a_copy[ndim - 1] = k_end - k_start;

View File

@@ -1207,38 +1207,38 @@ class TestBlas(mlx_tests.MLXTestCase):
(10, 10, 1000), (10, 10, 1000),
(1000, 1000, 1000), (1000, 1000, 1000),
] ]
segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
for M, N, K in shapes: for M, N, K in shapes:
for s in segments: for s in all_segments:
segments = [] segments = []
for i in range(len(s) - 1): for i in range(len(s) - 1):
segments.append([s[i], s[i + 1]]) segments.append([s[i], s[i + 1]])
segments = mx.array(segments) segments = mx.array(segments)
segments = mx.maximum(K - 1, segments.astype(mx.uint32)) segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
a = mx.random.normal((M, K)) a = mx.random.normal((M, K))
b = mx.random.normal((K, N)) b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a, b, segments) c1 = segmented_mm_ref(a, b, segments)
c2 = mx.segmented_mm(a, b, segments) c2 = mx.segmented_mm(a, b, segments)
self.assertTrue(mx.allclose(c1, c2)) self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M)) a = mx.random.normal((K, M))
b = mx.random.normal((K, N)) b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a.T, b, segments) c1 = segmented_mm_ref(a.T, b, segments)
c2 = mx.segmented_mm(a.T, b, segments) c2 = mx.segmented_mm(a.T, b, segments)
self.assertTrue(mx.allclose(c1, c2)) self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((M, K)) a = mx.random.normal((M, K))
b = mx.random.normal((N, K)) b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a, b.T, segments) c1 = segmented_mm_ref(a, b.T, segments)
c2 = mx.segmented_mm(a, b.T, segments) c2 = mx.segmented_mm(a, b.T, segments)
self.assertTrue(mx.allclose(c1, c2)) self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M)) a = mx.random.normal((K, M))
b = mx.random.normal((N, K)) b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a.T, b.T, segments) c1 = segmented_mm_ref(a.T, b.T, segments)
c2 = mx.segmented_mm(a.T, b.T, segments) c2 = mx.segmented_mm(a.T, b.T, segments)
self.assertTrue(mx.allclose(c1, c2)) self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
a = mx.ones((2, 10, 10)) a = mx.ones((2, 10, 10))