fix matrix sdpa

Awni Hannun
2025-09-09 12:52:56 -07:00
parent 0fe25eb588
commit 836f019d3b
2 changed files with 4 additions and 5 deletions


@@ -171,7 +171,7 @@ template <
   VBlockLoader loader_v(
       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 
-  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
+  TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
 
   // Prepare MMA tiles
   constexpr short kFragSize = 8; // MMAFrag size
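
The literal being replaced, 1.44269504089, is log2(e); Metal's M_LOG2E_F is the same constant. It is folded into the attention scale because the kernel evaluates the softmax with base-2 exponentials (exp2). A minimal sketch of the identity, in Python for illustration:

    import math

    # exp(x) can be computed as exp2(x * log2(e)); folding log2(e) into the
    # attention scale lets the kernel use exp2 throughout the softmax.
    LOG2E = math.log2(math.e)  # 1.4426950408889634, M_LOG2E_F at full precision
    x = 0.37
    assert math.isclose(math.exp(x), 2.0 ** (x * LOG2E))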
@@ -234,11 +234,10 @@ template <
     max_score[i] = Limits<AccumType>::finite_min;
   }
 
-  // TODO condition here is wrong
   if (has_sinks) {
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kRowsPT; ++i) {
-      max_score[i] = static_cast<AccumType>(sinks[tidl.y]);
+      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
       sum_score[i] = 1;
     }
   }
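
Initializing the running max to the (log2-scaled) sink logit and the running sum to 1 appears to seed the online softmax as if the sink were the first processed element, so the sink ends up in the denominator without producing an output row. A small Python sketch of that idea, written with natural exp for clarity (names are illustrative, not taken from the kernel):

    import math

    def softmax_with_sink(logits, sink):
        # Start the running state as if the sink logit were already seen:
        # running_max = sink and running_sum = exp(sink - running_max) = 1.
        running_max = sink
        running_sum = 1.0
        for x in logits:
            new_max = max(running_max, x)
            running_sum = running_sum * math.exp(running_max - new_max)
            running_sum += math.exp(x - new_max)
            running_max = new_max
        # The sink contributes to the denominator only; no weight is emitted for it.
        return [math.exp(x - running_max) / running_sum for x in logits]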
@@ -361,7 +360,7 @@ template <
             Stile.frag_at(i, j)[jj] =
                 mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
           } else {
-            Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+            Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
           }
         }
       }


@@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
             mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
 
         for T_kv in [128, 4096]:
-            for T_q in [1]:  # , 128]:
+            for T_q in [1, 128]:
                 for N_kv in [2, 8]:
                     q = mx.random.normal(shape=(B, N_q, T_q, D))
                     k = mx.random.normal(shape=(B, N_kv, T_kv, D))
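
For context, a sketch of the kind of call the re-enabled T_q = 128 case exercises; the shapes, head counts, and the per-head layout of sinks below are assumptions for illustration rather than values taken from the test file:

    import mlx.core as mx

    # Illustrative (assumed) shapes: batch 1, 8 query heads, 2 KV heads, head dim 64.
    B, N_q, N_kv, T_q, T_kv, D = 1, 8, 2, 128, 4096, 64
    scale = D**-0.5
    q = mx.random.normal(shape=(B, N_q, T_q, D))
    k = mx.random.normal(shape=(B, N_kv, T_kv, D))
    v = mx.random.normal(shape=(B, N_kv, T_kv, D))
    sinks = mx.random.normal(shape=(N_q,))  # assumed: one sink logit per query head
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)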