mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
Complete 2 pass sdpav
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
U k[v_per_thread];
|
||||
U o[v_per_thread];
|
||||
|
||||
__shared__ U outputs[BD][BN + 1];
|
||||
__shared__ U outputs[BN][BD + 1];
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale; // * 1.44269504089f;
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
const int warp_idx = warp.meta_group_rank();
|
||||
|
||||
// Adjust to thread block and thread
|
||||
const int batch_idx = 0; // blockIdx.z / blocks;
|
||||
const int batch_idx = blockIdx.z / blocks;
|
||||
const int block_idx = blockIdx.z % blocks;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = expf(max_score - new_max);
|
||||
U exp_score = expf(score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = expf(max_score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
||||
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
||||
|
||||
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
auto ff = exp2f(max_scores[warp_idx] - new_max);
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
|
||||
outputs[warp_idx][lane_idx] = o[i] * ff;
|
||||
block.sync();
|
||||
|
||||
if (warp_idx == 0) {
|
||||
U ot = outputs[lane_idx][0];
|
||||
|
||||
U ot = outputs[0][lane_idx];
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 1; j < BN; j++) {
|
||||
ot += outputs[lane_idx][0];
|
||||
ot += outputs[j][lane_idx];
|
||||
warp.sync();
|
||||
}
|
||||
|
||||
// o[i] = ot;
|
||||
partials[v_per_thread * lane_idx + i] = ot;
|
||||
o[i] = ot;
|
||||
}
|
||||
block.sync();
|
||||
}
|
||||
|
||||
// if(warp_idx == 0) {
|
||||
// PRAGMA_LOOP_UNROLL
|
||||
// for (int i = 0; i < v_per_thread; i++) {
|
||||
// partials[v_per_thread * lane_idx + i] = o[i];
|
||||
// }
|
||||
// }
|
||||
if (warp_idx == 0) {
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
partials[v_per_thread * lane_idx + i] = o[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool do_causal, int D>
|
||||
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
|
||||
|
||||
U max_score = maxs[lane_idx];
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = expf(max_score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
||||
// sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
|
||||
outputs[lane_idx][warp_idx] = o[i];
|
||||
block.sync();
|
||||
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
|
||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
|
||||
block.sync();
|
||||
}
|
||||
|
||||
@@ -504,6 +504,7 @@ void sdpa_vector_1pass_fallback(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
@@ -585,6 +586,7 @@ void sdpa_vector_2pass_fallback(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
@@ -610,6 +612,7 @@ void sdpa_vector_2pass_fallback(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
@@ -632,7 +635,7 @@ void sdpa_vector_fallback(
|
||||
bool do_causal_ = false) {
|
||||
int kL = k.shape(2);
|
||||
|
||||
if (false && kL > 1024) {
|
||||
if (kL > 1024) {
|
||||
return sdpa_vector_2pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
} else {
|
||||
@@ -1034,7 +1037,23 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
int64_t str_oD = 1;
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
int64_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ 0,
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
|
||||
Reference in New Issue
Block a user