mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix matrix sdpa
This commit is contained in:
@@ -171,7 +171,7 @@ template <
|
|||||||
VBlockLoader loader_v(
|
VBlockLoader loader_v(
|
||||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
|
TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
|
||||||
|
|
||||||
// Prepare MMA tiles
|
// Prepare MMA tiles
|
||||||
constexpr short kFragSize = 8; // MMAFrag size
|
constexpr short kFragSize = 8; // MMAFrag size
|
||||||
@@ -234,11 +234,10 @@ template <
|
|||||||
max_score[i] = Limits<AccumType>::finite_min;
|
max_score[i] = Limits<AccumType>::finite_min;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO condition here is wrong
|
|
||||||
if (has_sinks) {
|
if (has_sinks) {
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
max_score[i] = static_cast<AccumType>(sinks[tidl.y]);
|
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
|
||||||
sum_score[i] = 1;
|
sum_score[i] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -361,7 +360,7 @@ template <
|
|||||||
Stile.frag_at(i, j)[jj] =
|
Stile.frag_at(i, j)[jj] =
|
||||||
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
|
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
|
||||||
} else {
|
} else {
|
||||||
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
|
Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
|
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
|
||||||
|
|
||||||
for T_kv in [128, 4096]:
|
for T_kv in [128, 4096]:
|
||||||
for T_q in [1]: # , 128]:
|
for T_q in [1, 128]:
|
||||||
for N_kv in [2, 8]:
|
for N_kv in [2, 8]:
|
||||||
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
||||||
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||||
|
|||||||
Reference in New Issue
Block a user