mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix matrix sdpa
This commit is contained in:
@@ -171,7 +171,7 @@ template <
   VBlockLoader loader_v(
       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 
-  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
+  TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
 
   // Prepare MMA tiles
   constexpr short kFragSize = 8; // MMAFrag size
@@ -234,11 +234,10 @@ template <
     max_score[i] = Limits<AccumType>::finite_min;
   }
 
-  // TODO condition here is wrong
   if (has_sinks) {
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kRowsPT; ++i) {
-      max_score[i] = static_cast<AccumType>(sinks[tidl.y]);
+      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
       sum_score[i] = 1;
     }
   }
@@ -361,7 +360,7 @@ template <
         Stile.frag_at(i, j)[jj] =
             mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
       } else {
-        Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+        Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
       }
     }
   }

@@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
         mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
 
         for T_kv in [128, 4096]:
-            for T_q in [1]: # , 128]:
+            for T_q in [1, 128]:
                 for N_kv in [2, 8]:
                     q = mx.random.normal(shape=(B, N_q, T_q, D))
                     k = mx.random.normal(shape=(B, N_kv, T_kv, D))
Reference in New Issue
Block a user