From 6a40e1c1767f7f4423546ac7c74517d7ace7a849 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Mon, 24 Mar 2025 12:28:00 -0700 Subject: [PATCH] Fix looping limit in causal attention (#1999) --- mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index a8469e0ff..b672c34e6 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -237,6 +237,7 @@ template < if (do_causal) { int q_max = (tid.x + 1) * BQ + params->qL_off; kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); } // Loop over KV seq length @@ -290,7 +291,7 @@ template < } // Mask out if causal - if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();