mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
7f8ba2a003
...
sdpav-back
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22d0bf273 | ||
|
|
99d8de8445 | ||
|
|
c66b76a8c8 | ||
|
|
f81edd184f |
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
U k[v_per_thread];
|
||||
U o[v_per_thread];
|
||||
|
||||
__shared__ U outputs[BD][BN + 1];
|
||||
__shared__ U outputs[BN][BD + 1];
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale; // * 1.44269504089f;
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
const int warp_idx = warp.meta_group_rank();
|
||||
|
||||
// Adjust to thread block and thread
|
||||
const int batch_idx = 0; // blockIdx.z / blocks;
|
||||
const int batch_idx = blockIdx.z / blocks;
|
||||
const int block_idx = blockIdx.z % blocks;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = expf(max_score - new_max);
|
||||
U exp_score = expf(score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = expf(max_score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
||||
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
||||
|
||||
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
auto ff = exp2f(max_scores[warp_idx] - new_max);
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
|
||||
outputs[warp_idx][lane_idx] = o[i] * ff;
|
||||
block.sync();
|
||||
|
||||
if (warp_idx == 0) {
|
||||
U ot = outputs[lane_idx][0];
|
||||
|
||||
U ot = outputs[0][lane_idx];
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 1; j < BN; j++) {
|
||||
ot += outputs[lane_idx][0];
|
||||
ot += outputs[j][lane_idx];
|
||||
warp.sync();
|
||||
}
|
||||
|
||||
// o[i] = ot;
|
||||
partials[v_per_thread * lane_idx + i] = ot;
|
||||
o[i] = ot;
|
||||
}
|
||||
block.sync();
|
||||
}
|
||||
|
||||
// if(warp_idx == 0) {
|
||||
// PRAGMA_LOOP_UNROLL
|
||||
// for (int i = 0; i < v_per_thread; i++) {
|
||||
// partials[v_per_thread * lane_idx + i] = o[i];
|
||||
// }
|
||||
// }
|
||||
if (warp_idx == 0) {
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
partials[v_per_thread * lane_idx + i] = o[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool do_causal, int D>
|
||||
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
|
||||
|
||||
U max_score = maxs[lane_idx];
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = expf(max_score - new_max);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
||||
// sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
|
||||
outputs[lane_idx][warp_idx] = o[i];
|
||||
block.sync();
|
||||
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
|
||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
|
||||
block.sync();
|
||||
}
|
||||
|
||||
@@ -499,11 +499,13 @@ void sdpa_vector_1pass_fallback(
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
|
||||
auto kernel =
|
||||
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
@@ -568,8 +570,8 @@ void sdpa_vector_2pass_fallback(
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
{
|
||||
auto kernel =
|
||||
cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
|
||||
auto kernel = cu::
|
||||
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
@@ -585,6 +587,7 @@ void sdpa_vector_2pass_fallback(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
@@ -595,8 +598,8 @@ void sdpa_vector_2pass_fallback(
|
||||
}
|
||||
|
||||
{
|
||||
auto kernel =
|
||||
cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
|
||||
auto kernel = cu::
|
||||
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
|
||||
|
||||
encoder.set_input_array(intermediate);
|
||||
encoder.set_input_array(sums);
|
||||
@@ -610,6 +613,7 @@ void sdpa_vector_2pass_fallback(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
@@ -632,7 +636,7 @@ void sdpa_vector_fallback(
|
||||
bool do_causal_ = false) {
|
||||
int kL = k.shape(2);
|
||||
|
||||
if (false && kL > 1024) {
|
||||
if (kL > 1024) {
|
||||
return sdpa_vector_2pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
} else {
|
||||
@@ -953,7 +957,20 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||
|
||||
return has_arr_mask || !sdpa_supported_head_dim;
|
||||
const bool supported_vector_config =
|
||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||
|
||||
auto& cu_device = cu::device(s.device);
|
||||
|
||||
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||
cu_device.compute_capability_major() >= 8 &&
|
||||
query_sequence_length == key_sequence_length &&
|
||||
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||
|
||||
const bool supported_config =
|
||||
(supported_matrix_config || supported_vector_config);
|
||||
|
||||
return has_arr_mask || !supported_config;
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
@@ -990,7 +1007,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) <= 1) {
|
||||
if (q_pre.shape(2) < 4) {
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
@@ -1034,11 +1051,26 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
int64_t str_oD = 1;
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
int64_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ 0,
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
// return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
}
|
||||
|
||||
// Full attention mode
|
||||
@@ -1075,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user