Fix looping limit in causal attention (#1999)

This commit is contained in:
Jagrit Digani 2025-03-24 12:28:00 -07:00 committed by GitHub
parent 9307b2ab8b
commit 6a40e1c176
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -237,6 +237,7 @@ template <
if (do_causal) { if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off; int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK; kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
} }
// Loop over KV seq length // Loop over KV seq length
@ -290,7 +291,7 @@ template <
} }
// Mask out if causal // Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();