add to cuda (#2580)

This commit is contained in:
Awni Hannun
2025-09-09 13:59:45 -07:00
committed by GitHub
parent ce54db388f
commit 3dbbcb4c7c

View File

@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
const T* K, const T* K,
const T* V, const T* V,
T* O, T* O,
const T* sinks,
__grid_constant__ const AttnParams params) { __grid_constant__ const AttnParams params) {
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BD = 32; constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN]; __shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN]; __shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f; const U scale_log2 = params.scale * M_LOG2E;
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block); auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
U max_score = -INFINITY; U max_score = -INFINITY;
U sum_exp_score = 0.f; U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key // For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) { for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max); bool is_neg_inf = new_max == -INFINITY;
U exp_score = exp2f(score - new_max); U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max; max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q, const T* Q,
const T* K, const T* K,
const T* V, const T* V,
const T* sinks,
float* partials, float* partials,
float* sums, float* sums,
float* maxs, float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f; o[i] = 0.f;
} }
U max_score = -1e9; U max_score = -INFINITY;
U sum_exp_score = 0.f; U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key // For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max); bool is_neg_inf = new_max == -INFINITY;
U exp_score = exp2f(score - new_max); U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max; max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
@@ -468,6 +480,9 @@ void sdpa_vector_1pass_fallback(
encoder.set_input_array(q); encoder.set_input_array(q);
encoder.set_input_array(k); encoder.set_input_array(k);
encoder.set_input_array(v); encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o); encoder.set_output_array(o);
cu::AttnParams params{ cu::AttnParams params{
@@ -505,6 +520,7 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(), k.data<DataType>(),
v.data<DataType>(), v.data<DataType>(),
o.data<DataType>(), o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params); params);
}); });
}); });
@@ -572,6 +588,10 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q); encoder.set_input_array(q);
encoder.set_input_array(k); encoder.set_input_array(k);
encoder.set_input_array(v); encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.set_output_array(sums); encoder.set_output_array(sums);
encoder.set_output_array(maxs); encoder.set_output_array(maxs);
@@ -587,6 +607,7 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(), q.data<DataType>(),
k.data<DataType>(), k.data<DataType>(),
v.data<DataType>(), v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(), intermediate.data<float>(),
sums.data<float>(), sums.data<float>(),
maxs.data<float>(), maxs.data<float>(),