mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 10:51:21 +08:00
Fix looping limit in causal attention (#1999)
This commit is contained in:
parent
9307b2ab8b
commit
6a40e1c176
@@ -237,6 +237,7 @@ template <
|
||||
if (do_causal) {
|
||||
int q_max = (tid.x + 1) * BQ + params->qL_off;
|
||||
kb_lim = (q_max + BK - 1) / BK;
|
||||
+ kb_lim = min(params->NK, kb_lim);
|
||||
}
|
||||
|
||||
// Loop over KV seq length
|
||||
@@ -290,7 +291,7 @@ template <
|
||||
}
|
||||
|
||||
// Mask out if causal
|
||||
- if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
|
||||
+ if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
|
Loading…
Reference in New Issue
Block a user