mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 01:08:10 +08:00
Complete 2 pass sdpav
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/config.h"
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/lru_cache.h"
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
|
|||||||
U k[v_per_thread];
|
U k[v_per_thread];
|
||||||
U o[v_per_thread];
|
U o[v_per_thread];
|
||||||
|
|
||||||
__shared__ U outputs[BD][BN + 1];
|
__shared__ U outputs[BN][BD + 1];
|
||||||
__shared__ U max_scores[BN];
|
__shared__ U max_scores[BN];
|
||||||
__shared__ U sum_exp_scores[BN];
|
__shared__ U sum_exp_scores[BN];
|
||||||
|
|
||||||
const U scale_log2 = params.scale; // * 1.44269504089f;
|
const U scale_log2 = params.scale * 1.44269504089f;
|
||||||
|
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<32>(block);
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
|
|||||||
const int warp_idx = warp.meta_group_rank();
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
// Adjust to thread block and thread
|
// Adjust to thread block and thread
|
||||||
const int batch_idx = 0; // blockIdx.z / blocks;
|
const int batch_idx = blockIdx.z / blocks;
|
||||||
const int block_idx = blockIdx.z % blocks;
|
const int block_idx = blockIdx.z % blocks;
|
||||||
const int head_idx = blockIdx.x;
|
const int head_idx = blockIdx.x;
|
||||||
const int kv_head_idx = head_idx / params.gqa_factor;
|
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||||
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
|
|||||||
|
|
||||||
// Update the accumulators
|
// Update the accumulators
|
||||||
U new_max = max(max_score, score);
|
U new_max = max(max_score, score);
|
||||||
U factor = expf(max_score - new_max);
|
U factor = exp2f(max_score - new_max);
|
||||||
U exp_score = expf(score - new_max);
|
U exp_score = exp2f(score - new_max);
|
||||||
|
|
||||||
max_score = new_max;
|
max_score = new_max;
|
||||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
|
|||||||
|
|
||||||
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
||||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||||
U factor = expf(max_score - new_max);
|
U factor = exp2f(max_score - new_max);
|
||||||
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
||||||
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
||||||
|
|
||||||
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now we need to aggregate all the outputs
|
// Now we need to aggregate all the outputs
|
||||||
|
auto ff = exp2f(max_scores[warp_idx] - new_max);
|
||||||
PRAGMA_LOOP_UNROLL
|
PRAGMA_LOOP_UNROLL
|
||||||
for (int i = 0; i < v_per_thread; i++) {
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
|
outputs[warp_idx][lane_idx] = o[i] * ff;
|
||||||
block.sync();
|
block.sync();
|
||||||
|
|
||||||
if (warp_idx == 0) {
|
if (warp_idx == 0) {
|
||||||
U ot = outputs[lane_idx][0];
|
U ot = outputs[0][lane_idx];
|
||||||
|
|
||||||
PRAGMA_LOOP_UNROLL
|
PRAGMA_LOOP_UNROLL
|
||||||
for (int j = 1; j < BN; j++) {
|
for (int j = 1; j < BN; j++) {
|
||||||
ot += outputs[lane_idx][0];
|
ot += outputs[j][lane_idx];
|
||||||
|
warp.sync();
|
||||||
}
|
}
|
||||||
|
o[i] = ot;
|
||||||
// o[i] = ot;
|
|
||||||
partials[v_per_thread * lane_idx + i] = ot;
|
|
||||||
}
|
}
|
||||||
block.sync();
|
block.sync();
|
||||||
}
|
}
|
||||||
|
|
||||||
// if(warp_idx == 0) {
|
if (warp_idx == 0) {
|
||||||
// PRAGMA_LOOP_UNROLL
|
PRAGMA_LOOP_UNROLL
|
||||||
// for (int i = 0; i < v_per_thread; i++) {
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
// partials[v_per_thread * lane_idx + i] = o[i];
|
partials[v_per_thread * lane_idx + i] = o[i];
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool do_causal, int D>
|
template <typename T, bool do_causal, int D>
|
||||||
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
|
|||||||
|
|
||||||
U max_score = maxs[lane_idx];
|
U max_score = maxs[lane_idx];
|
||||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||||
U factor = expf(max_score - new_max);
|
U factor = exp2f(max_score - new_max);
|
||||||
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
||||||
// sum_exp_score = __frcp_rn(sum_exp_score);
|
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||||
|
|
||||||
PRAGMA_LOOP_UNROLL
|
PRAGMA_LOOP_UNROLL
|
||||||
for (int i = 0; i < v_per_thread; i++) {
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
|
|||||||
outputs[lane_idx][warp_idx] = o[i];
|
outputs[lane_idx][warp_idx] = o[i];
|
||||||
block.sync();
|
block.sync();
|
||||||
U ot = outputs[warp_idx][lane_idx] * factor;
|
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
|
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
|
||||||
block.sync();
|
block.sync();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,6 +504,7 @@ void sdpa_vector_1pass_fallback(
|
|||||||
kernel,
|
kernel,
|
||||||
grid_dim,
|
grid_dim,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
q.data<DataType>(),
|
q.data<DataType>(),
|
||||||
k.data<DataType>(),
|
k.data<DataType>(),
|
||||||
v.data<DataType>(),
|
v.data<DataType>(),
|
||||||
@@ -585,6 +586,7 @@ void sdpa_vector_2pass_fallback(
|
|||||||
kernel,
|
kernel,
|
||||||
grid_dim,
|
grid_dim,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
q.data<DataType>(),
|
q.data<DataType>(),
|
||||||
k.data<DataType>(),
|
k.data<DataType>(),
|
||||||
v.data<DataType>(),
|
v.data<DataType>(),
|
||||||
@@ -610,6 +612,7 @@ void sdpa_vector_2pass_fallback(
|
|||||||
kernel,
|
kernel,
|
||||||
grid_dim,
|
grid_dim,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
intermediate.data<float>(),
|
intermediate.data<float>(),
|
||||||
sums.data<float>(),
|
sums.data<float>(),
|
||||||
maxs.data<float>(),
|
maxs.data<float>(),
|
||||||
@@ -632,7 +635,7 @@ void sdpa_vector_fallback(
|
|||||||
bool do_causal_ = false) {
|
bool do_causal_ = false) {
|
||||||
int kL = k.shape(2);
|
int kL = k.shape(2);
|
||||||
|
|
||||||
if (false && kL > 1024) {
|
if (kL > 1024) {
|
||||||
return sdpa_vector_2pass_fallback(
|
return sdpa_vector_2pass_fallback(
|
||||||
s, encoder, q, k, v, scale, o, do_causal_);
|
s, encoder, q, k, v, scale, o, do_causal_);
|
||||||
} else {
|
} else {
|
||||||
@@ -1034,7 +1037,23 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||||
o.copy_shared_buffer(q);
|
o.copy_shared_buffer(q);
|
||||||
} else {
|
} else {
|
||||||
o.set_data(allocator::malloc(o.nbytes()));
|
int64_t str_oD = 1;
|
||||||
|
int64_t str_oH = o.shape(3);
|
||||||
|
int64_t str_oL = o.shape(1) * str_oH;
|
||||||
|
int64_t str_oB = o.shape(2) * str_oL;
|
||||||
|
size_t data_size = o.shape(0) * str_oB;
|
||||||
|
|
||||||
|
array::Flags flags{
|
||||||
|
/* bool contiguous = */ 1,
|
||||||
|
/* bool row_contiguous = */ 0,
|
||||||
|
/* bool col_contiguous = */ 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
o.set_data(
|
||||||
|
allocator::malloc(o.nbytes()),
|
||||||
|
data_size,
|
||||||
|
{str_oB, str_oH, str_oL, str_oD},
|
||||||
|
flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
|
|||||||
Reference in New Issue
Block a user