diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 079a0baff4..f600a48900 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -288,11 +288,9 @@ void ScaledDotProductAttention::eval_gpu( strides[0] == strides[1] * shape[1]; }; - // Checks that the last two dims are row contiguous. + // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return strides[3] == 1 && strides[2] == shape[3]; + return arr.strides(3) == 1; }; // We are in vector mode ie single query