diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index fb277b530..fbee6118f 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -81,6 +81,7 @@ inline void segmented_mm( uint32_t k_end = segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; if (k_end <= k_start) { + std::fill_n(out + i * M * N, M * N, T(0)); continue; } a_copy[ndim - 1] = k_end - k_start; diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 2490f3ab2..014e8a9dd 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1207,38 +1207,38 @@ class TestBlas(mlx_tests.MLXTestCase): (10, 10, 1000), (1000, 1000, 1000), ] - segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] + all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] for M, N, K in shapes: - for s in segments: + for s in all_segments: segments = [] for i in range(len(s) - 1): segments.append([s[i], s[i + 1]]) segments = mx.array(segments) - segments = mx.maximum(K - 1, segments.astype(mx.uint32)) + segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32)) a = mx.random.normal((M, K)) b = mx.random.normal((K, N)) c1 = segmented_mm_ref(a, b, segments) c2 = mx.segmented_mm(a, b, segments) - self.assertTrue(mx.allclose(c1, c2)) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) a = mx.random.normal((K, M)) b = mx.random.normal((K, N)) c1 = segmented_mm_ref(a.T, b, segments) c2 = mx.segmented_mm(a.T, b, segments) - self.assertTrue(mx.allclose(c1, c2)) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) a = mx.random.normal((M, K)) b = mx.random.normal((N, K)) c1 = segmented_mm_ref(a, b.T, segments) c2 = mx.segmented_mm(a, b.T, segments) - self.assertTrue(mx.allclose(c1, c2)) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) a = mx.random.normal((K, M)) b = mx.random.normal((N, K)) c1 = segmented_mm_ref(a.T, b.T, segments) c2 = mx.segmented_mm(a.T, b.T, segments) - self.assertTrue(mx.allclose(c1, c2)) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) with self.assertRaises(ValueError): a = mx.ones((2, 10, 10))