Complete 2 pass sdpav

Jagrit Digani
2025-08-06 13:57:40 -07:00
parent 7f8ba2a003
commit f81edd184f


@@ -2,6 +2,7 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
U k[v_per_thread];
U o[v_per_thread];
-__shared__ U outputs[BD][BN + 1];
+__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
-const U scale_log2 = params.scale; // * 1.44269504089f;
+const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
-const int batch_idx = 0; // blockIdx.z / blocks;
+const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
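
A small host-side sketch of the grid packing that the divide and modulo above undo; gridDim.z is assumed to carry one slot per (batch, kv-block) pair, and the names here are illustrative, not MLX's launch code:

    #include <cstdio>

    int main() {
      int batch = 4, blocks = 32;
      int grid_z = batch * blocks; // launch side: one z slot per (batch, kv-block) pair
      for (int z = 0; z < grid_z; z += 37) {
        int batch_idx = z / blocks; // kernel-side decode
        int block_idx = z % blocks;
        std::printf("z=%d -> batch=%d block=%d\n", z, batch_idx, block_idx);
      }
    }
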
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
// Update the accumulators
U new_max = max(max_score, score);
-U factor = expf(max_score - new_max);
-U exp_score = expf(score - new_max);
+U factor = exp2f(max_score - new_max);
+U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
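
The base-2 exponentials above rely on exp(x) = 2^(x * log2 e): once the scale is pre-multiplied by 1.44269504089 (log2 e), every exponential can use exp2f, which typically maps to a cheaper hardware approximation, while the softmax weights stay the same. A minimal host-side sketch of the online-softmax update in both bases (illustrative values, not MLX code):

    #include <cmath>
    #include <cstdio>

    int main() {
      const float log2e = 1.44269504089f;
      float scores[4] = {0.3f, -1.2f, 2.5f, 0.7f};
      float max_e = -1e9f, sum_e = 0.f; // natural-base accumulators
      float max_2 = -1e9f, sum_2 = 0.f; // base-2 accumulators
      for (float s : scores) {
        float new_max = std::fmax(max_e, s);
        sum_e = sum_e * std::exp(max_e - new_max) + std::exp(s - new_max);
        max_e = new_max;
        float s2 = s * log2e; // pre-scaled score, as with scale_log2
        float new_max2 = std::fmax(max_2, s2);
        sum_2 = sum_2 * std::exp2(max_2 - new_max2) + std::exp2(s2 - new_max2);
        max_2 = new_max2;
      }
      std::printf("%f %f\n", sum_e, sum_2); // the two normalizers agree
    }
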
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-U factor = expf(max_score - new_max);
+U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
}
// Now we need to aggregate all the outputs
+auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
-outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
+outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();
if (warp_idx == 0) {
-U ot = outputs[lane_idx][0];
+U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
-ot += outputs[lane_idx][0];
+ot += outputs[j][lane_idx];
warp.sync();
}
-// o[i] = ot;
-partials[v_per_thread * lane_idx + i] = ot;
+o[i] = ot;
}
block.sync();
}
-// if(warp_idx == 0) {
-// PRAGMA_LOOP_UNROLL
-// for (int i = 0; i < v_per_thread; i++) {
-// partials[v_per_thread * lane_idx + i] = o[i];
-// }
-// }
+if (warp_idx == 0) {
+PRAGMA_LOOP_UNROLL
+for (int i = 0; i < v_per_thread; i++) {
+partials[v_per_thread * lane_idx + i] = o[i];
+}
+}
}
template <typename T, bool do_causal, int D>
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-U factor = expf(max_score - new_max);
+U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-// sum_exp_score = __frcp_rn(sum_exp_score);
+sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
-o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
+o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
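
As the pass-2 hunks above suggest, each pass-1 block leaves behind a partial output plus its running max and sum of exponentials; pass 2 rescales every partial by exp2(max_i - global max), sums them, and multiplies once by the reciprocal of the combined normalizer (__frcp_rn gives an approximate reciprocal so the per-element division becomes a multiply). A host-side sketch of that combine step, with illustrative names rather than the kernel's:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    float combine(const std::vector<float>& partial_o,
                  const std::vector<float>& maxs,
                  const std::vector<float>& sums) {
      float new_max = -1e9f;
      for (float m : maxs) new_max = std::fmax(new_max, m);
      float sum_exp = 0.f, out = 0.f;
      for (size_t i = 0; i < maxs.size(); i++) {
        float factor = std::exp2(maxs[i] - new_max); // rescale to the global max
        sum_exp += sums[i] * factor;
        out += partial_o[i] * factor;
      }
      return out / sum_exp; // the kernel multiplies by __frcp_rn(sum_exp) instead
    }

    int main() {
      std::printf("%f\n", combine({1.0f, 2.0f}, {0.5f, 1.5f}, {2.0f, 3.0f}));
    }
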
@@ -504,6 +504,7 @@ void sdpa_vector_1pass_fallback(
kernel,
grid_dim,
block_dim,
+0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
@@ -585,6 +586,7 @@ void sdpa_vector_2pass_fallback(
kernel,
grid_dim,
block_dim,
+0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
@@ -610,6 +612,7 @@ void sdpa_vector_2pass_fallback(
kernel,
grid_dim,
block_dim,
+0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
@@ -632,7 +635,7 @@ void sdpa_vector_fallback(
bool do_causal_ = false) {
int kL = k.shape(2);
-if (false && kL > 1024) {
+if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
} else {
@@ -1034,7 +1037,23 @@ void ScaledDotProductAttention::eval_gpu(
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
-o.set_data(allocator::malloc(o.nbytes()));
+int64_t str_oD = 1;
+int64_t str_oH = o.shape(3);
+int64_t str_oL = o.shape(1) * str_oH;
+int64_t str_oB = o.shape(2) * str_oL;
+size_t data_size = o.shape(0) * str_oB;
+array::Flags flags{
+/* bool contiguous = */ 1,
+/* bool row_contiguous = */ 0,
+/* bool col_contiguous = */ 0,
+};
+o.set_data(
+allocator::malloc(o.nbytes()),
+data_size,
+{str_oB, str_oH, str_oL, str_oD},
+flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
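
The stride arithmetic in the hunk above appears to keep the output's logical (B, H, L, D) shape while laying the buffer out in memory as (B, L, H, D), which is why the flags mark it contiguous but not row-contiguous. A small sketch of those relationships; the sizes are made up for illustration:

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t B = 2, H = 8, L = 128, D = 64; // hypothetical sizes
      int64_t str_D = 1;
      int64_t str_H = D;          // heads are adjacent within one (batch, position)
      int64_t str_L = H * str_H;  // one step along L skips all heads
      int64_t str_B = L * str_L;  // one step along B skips the whole sequence
      int64_t data_size = B * str_B;
      // Strides reported in (B, H, L, D) order, matching the logical shape:
      std::cout << str_B << " " << str_H << " " << str_L << " " << str_D
                << " elements=" << data_size << "\n";
    }
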