diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index cd0680131..fb277b530 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -80,6 +80,9 @@ inline void segmented_mm( segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; uint32_t k_end = segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; + if (k_end <= k_start) { + continue; + } a_copy[ndim - 1] = k_end - k_start; b_copy[ndim - 2] = k_end - k_start; matmul(