mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add to cuda (#2580)
This commit is contained in:
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
const T* K,
|
||||
const T* V,
|
||||
T* O,
|
||||
const T* sinks,
|
||||
__grid_constant__ const AttnParams params) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
const U scale_log2 = params.scale * M_LOG2E;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += BN) {
|
||||
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
bool is_neg_inf = new_max == -INFINITY;
|
||||
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
|
||||
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
const T* Q,
|
||||
const T* K,
|
||||
const T* V,
|
||||
const T* sinks,
|
||||
float* partials,
|
||||
float* sums,
|
||||
float* maxs,
|
||||
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0 && block_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
|
||||
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
bool is_neg_inf = new_max == -INFINITY;
|
||||
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
|
||||
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -468,6 +480,9 @@ void sdpa_vector_1pass_fallback(
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
encoder.set_output_array(o);
|
||||
|
||||
cu::AttnParams params{
|
||||
@@ -505,6 +520,7 @@ void sdpa_vector_1pass_fallback(
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
o.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
params);
|
||||
});
|
||||
});
|
||||
@@ -572,6 +588,10 @@ void sdpa_vector_2pass_fallback(
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.set_output_array(sums);
|
||||
encoder.set_output_array(maxs);
|
||||
@@ -587,6 +607,7 @@ void sdpa_vector_2pass_fallback(
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
|
||||
Reference in New Issue
Block a user