mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: 7f8ba2a003 ... sdpav-back (4 commits)

| SHA1 |
|---|
| a22d0bf273 |
| 99d8de8445 |
| c66b76a8c8 |
| f81edd184f |

```diff
@@ -2,6 +2,7 @@
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
```

```diff
@@ -217,11 +218,11 @@ __global__ void kernel_sdpav_2pass_1(
   U k[v_per_thread];
   U o[v_per_thread];
 
-  __shared__ U outputs[BD][BN + 1];
+  __shared__ U outputs[BN][BD + 1];
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];
 
-  const U scale_log2 = params.scale; // * 1.44269504089f;
+  const U scale_log2 = params.scale * 1.44269504089f;
 
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
```

```diff
@@ -230,7 +231,7 @@ __global__ void kernel_sdpav_2pass_1(
   const int warp_idx = warp.meta_group_rank();
 
   // Adjust to thread block and thread
-  const int batch_idx = 0; // blockIdx.z / blocks;
+  const int batch_idx = blockIdx.z / blocks;
   const int block_idx = blockIdx.z % blocks;
   const int head_idx = blockIdx.x;
   const int kv_head_idx = head_idx / params.gqa_factor;
```

```diff
@@ -302,8 +303,8 @@ __global__ void kernel_sdpav_2pass_1(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = expf(max_score - new_max);
-    U exp_score = expf(score - new_max);
+    U factor = exp2f(max_score - new_max);
+    U exp_score = exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
```

```diff
@@ -330,7 +331,7 @@ __global__ void kernel_sdpav_2pass_1(
 
   max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = expf(max_score - new_max);
+  U factor = exp2f(max_score - new_max);
   sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
   sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
 
```

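The hunks above move the kernel's online softmax into the log2 domain: log2(e) is folded into the scale up front and the expf calls on the score path become exp2f. A minimal host-side sketch of the identity behind that change (the scale and score values below are made up):

```cpp
#include <cmath>
#include <cstdio>

// exp(x) == exp2(x * log2(e)), so pre-multiplying the softmax scale by
// 1.44269504089f (log2 e) and switching expf -> exp2f leaves the attention
// weights unchanged while using the cheaper base-2 exponential on the GPU.
int main() {
  float scale = 0.125f; // hypothetical 1/sqrt(D) for a 64-dim head
  float score = 3.7f;   // hypothetical q.k dot product
  float base_e = std::exp(score * scale);
  float base_2 = std::exp2(score * scale * 1.44269504089f);
  std::printf("base-e %.7f, base-2 %.7f\n", base_e, base_2); // agree to fp error
  return 0;
}
```
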
```diff
@@ -341,31 +342,30 @@ __global__ void kernel_sdpav_2pass_1(
   }
 
   // Now we need to aggregate all the outputs
+  auto ff = exp2f(max_scores[warp_idx] - new_max);
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
-    outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
+    outputs[warp_idx][lane_idx] = o[i] * ff;
     block.sync();
 
     if (warp_idx == 0) {
-      U ot = outputs[lane_idx][0];
+      U ot = outputs[0][lane_idx];
 
       PRAGMA_LOOP_UNROLL
       for (int j = 1; j < BN; j++) {
-        ot += outputs[lane_idx][0];
+        ot += outputs[j][lane_idx];
+        warp.sync();
       }
-      // o[i] = ot;
-      partials[v_per_thread * lane_idx + i] = ot;
+      o[i] = ot;
     }
     block.sync();
   }
 
-  // if(warp_idx == 0) {
-  //   PRAGMA_LOOP_UNROLL
-  //   for (int i = 0; i < v_per_thread; i++) {
-  //     partials[v_per_thread * lane_idx + i] = o[i];
-  //   }
-  // }
+  if (warp_idx == 0) {
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < v_per_thread; i++) {
+      partials[v_per_thread * lane_idx + i] = o[i];
+    }
+  }
 }
 
 template <typename T, bool do_causal, int D>
```

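Together with the earlier declaration change, this hunk transposes the shared staging tile from [BD][BN + 1] to [BN][BD + 1]: each warp now writes one row (outputs[warp_idx][lane_idx]), warp 0 reduces down the BN rows with one column per lane, and the write to partials moves into its own loop after the aggregation. A standalone sketch of that access pattern, not the MLX kernel itself (BN, BD and the launch shape are assumed):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

constexpr int BN = 4;  // warps per block (assumed)
constexpr int BD = 32; // lanes per warp

__global__ void reduce_rows(const float* in, float* out) {
  __shared__ float tile[BN][BD + 1]; // +1 column of padding, as in the diff
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;

  tile[warp][lane] = in[warp * BD + lane]; // each warp fills its own row
  __syncthreads();

  if (warp == 0) {
    float acc = tile[0][lane];
    for (int j = 1; j < BN; j++) {
      acc += tile[j][lane]; // walk down the rows, one column per lane
    }
    out[lane] = acc;
  }
}

int main() {
  float h_in[BN * BD], h_out[BD];
  for (int i = 0; i < BN * BD; i++) h_in[i] = 1.0f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  reduce_rows<<<1, BN * 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f (expect %d)\n", h_out[0], BN);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```
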
```diff
@@ -414,9 +414,9 @@ __global__ void kernel_sdpav_2pass_2(
 
   U max_score = maxs[lane_idx];
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = expf(max_score - new_max);
+  U factor = exp2f(max_score - new_max);
   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-  // sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = __frcp_rn(sum_exp_score);
 
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
```

```diff
@@ -429,7 +429,7 @@ __global__ void kernel_sdpav_2pass_2(
     outputs[lane_idx][warp_idx] = o[i];
     block.sync();
     U ot = outputs[warp_idx][lane_idx] * factor;
-    o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
+    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
     block.sync();
   }
 
```

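Taken together, the last two hunks are a single refactor in the second pass: the softmax denominator is inverted once with __frcp_rn, and the per-element division becomes a multiplication. A minimal host-side sketch of the equivalence (values made up; __frcp_rn is a rounded single-precision reciprocal, so the last bit can differ from true division):

```cpp
#include <cstdio>

int main() {
  float sum_exp_score = 37.5f; // hypothetical softmax denominator
  float o[4] = {1.f, 2.f, 3.f, 4.f};
  float inv = 1.0f / sum_exp_score; // plays the role of __frcp_rn on device
  for (int i = 0; i < 4; i++) {
    float divided = o[i] / sum_exp_score; // old: divide every element
    float scaled = o[i] * inv;            // new: multiply by the reciprocal
    std::printf("%g %g\n", divided, scaled); // agree up to rounding
  }
  return 0;
}
```
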
```diff
@@ -499,11 +499,13 @@ void sdpa_vector_1pass_fallback(
     dispatch_headdim(params.D, [&](auto headdim) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
-      auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
+      auto kernel =
+          cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
       encoder.add_kernel_node(
           kernel,
           grid_dim,
           block_dim,
+          0,
           q.data<DataType>(),
           k.data<DataType>(),
           v.data<DataType>(),
```

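The kernel lookups above (and in the 2-pass dispatch below) now read the compile-time dispatch tags through .value instead of calling them. A simplified sketch of why both spellings name the same constant, using std::integral_constant as a stand-in for MLX's dispatch helpers (the kernel and helper names below are made up):

```cpp
#include <type_traits>

template <bool DoCausal, int D>
void fake_kernel() {} // stand-in for cu::kernel_sdpav_1pass<...>

// Minimal dispatcher: hands the lambda a tag carrying a compile-time bool.
template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::integral_constant<bool, true>{});
  } else {
    f(std::integral_constant<bool, false>{});
  }
}

int main() {
  dispatch_bool(true, [](auto do_causal) {
    // Both forms are constant expressions and pick the same instantiation:
    auto a = &fake_kernel<do_causal(), 64>;     // call operator() on the tag
    auto b = &fake_kernel<do_causal.value, 64>; // read the static member
    (void)a;
    (void)b;
  });
  return 0;
}
```
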
```diff
@@ -568,8 +570,8 @@ void sdpa_vector_2pass_fallback(
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
       {
-        auto kernel =
-            cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
+        auto kernel = cu::
+            kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
 
         encoder.set_input_array(q);
         encoder.set_input_array(k);
```

```diff
@@ -585,6 +587,7 @@ void sdpa_vector_2pass_fallback(
            kernel,
            grid_dim,
            block_dim,
+           0,
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
```

```diff
@@ -595,8 +598,8 @@ void sdpa_vector_2pass_fallback(
       }
 
       {
-        auto kernel =
-            cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
+        auto kernel = cu::
+            kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
 
         encoder.set_input_array(intermediate);
         encoder.set_input_array(sums);
```

```diff
@@ -610,6 +613,7 @@ void sdpa_vector_2pass_fallback(
            kernel,
            grid_dim,
            block_dim,
+           0,
            intermediate.data<float>(),
            sums.data<float>(),
            maxs.data<float>(),
```

```diff
@@ -632,7 +636,7 @@ void sdpa_vector_fallback(
     bool do_causal_ = false) {
   int kL = k.shape(2);
 
-  if (false && kL > 1024) {
+  if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
         s, encoder, q, k, v, scale, o, do_causal_);
   } else {
```

```diff
@@ -953,7 +957,20 @@ bool ScaledDotProductAttention::use_fallback(
   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
 
-  return has_arr_mask || !sdpa_supported_head_dim;
+  const bool supported_vector_config =
+      sdpa_supported_head_dim && query_sequence_length < 4;
+
+  auto& cu_device = cu::device(s.device);
+
+  const bool supported_matrix_config = query_sequence_length > 4 &&
+      cu_device.compute_capability_major() >= 8 &&
+      query_sequence_length == key_sequence_length &&
+      (q.dtype() == float16 || q.dtype() == bfloat16);
+
+  const bool supported_config =
+      (supported_matrix_config || supported_vector_config);
+
+  return has_arr_mask || !supported_config;
 }
 
 void ScaledDotProductAttention::eval_gpu(
```

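After this change, use_fallback returns false (so the fused CUDA path is taken) only for two kinds of shapes: short query sequences with 64/96/128-dim heads (the vector kernels), and longer fp16/bf16 self-attention shapes with equal query and key lengths on compute capability 8.0 or newer (the matrix path). Anything else, or any array mask, falls back. A standalone sketch of that predicate with simplified parameters (the example shapes are made up):

```cpp
#include <cstdio>

bool use_fallback(bool has_arr_mask, int q_len, int k_len, int head_dim,
                  bool half_precision, int cc_major) {
  bool supported_head_dim =
      head_dim == 64 || head_dim == 96 || head_dim == 128;
  bool vector_config = supported_head_dim && q_len < 4;
  bool matrix_config =
      q_len > 4 && cc_major >= 8 && q_len == k_len && half_precision;
  return has_arr_mask || !(vector_config || matrix_config);
}

int main() {
  // decode step: 1 query token, 128-dim heads -> vector path
  std::printf("%d\n", use_fallback(false, 1, 4096, 128, true, 9)); // 0
  // prompt processing: 2048 self-attention tokens, fp16, sm_90 -> matrix path
  std::printf("%d\n", use_fallback(false, 2048, 2048, 128, true, 9)); // 0
  // 8 query tokens against a longer cache -> falls back
  std::printf("%d\n", use_fallback(false, 8, 4096, 128, true, 9)); // 1
  return 0;
}
```
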
```diff
@@ -990,7 +1007,7 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   // We are in vector mode ie single query
-  if (q_pre.shape(2) <= 1) {
+  if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
       if (arr.flags().row_contiguous) {
         return true;
```

```diff
@@ -1034,11 +1051,26 @@ void ScaledDotProductAttention::eval_gpu(
     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
     } else {
-      o.set_data(allocator::malloc(o.nbytes()));
+      int64_t str_oD = 1;
+      int64_t str_oH = o.shape(3);
+      int64_t str_oL = o.shape(1) * str_oH;
+      int64_t str_oB = o.shape(2) * str_oL;
+      size_t data_size = o.shape(0) * str_oB;
+
+      array::Flags flags{
+          /* bool contiguous = */ 1,
+          /* bool row_contiguous = */ 0,
+          /* bool col_contiguous = */ 0,
+      };
+
+      o.set_data(
+          allocator::malloc(o.nbytes()),
+          data_size,
+          {str_oB, str_oH, str_oL, str_oD},
+          flags);
     }
 
     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
-    // return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
   }
 
   // Full attention mode
```

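Instead of allocating o with default row-contiguous strides, the branch above builds an explicitly strided output: assuming o is shaped (batch, heads, queries, head_dim) as elsewhere in this diff, the strides {str_oB, str_oH, str_oL, str_oD} place all heads of one query token next to each other in memory. A small arithmetic check of that layout (shape values are made up):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int B = 2, H = 8, L = 3, D = 64; // hypothetical o shape (B, H, L, D)
  int64_t str_oD = 1;
  int64_t str_oH = D;          // o.shape(3)
  int64_t str_oL = H * str_oH; // o.shape(1) * str_oH
  int64_t str_oB = L * str_oL; // o.shape(2) * str_oL

  // Element (b, h, l, d) under these strides lands at the same offset as in a
  // dense (B, L, H, D) buffer, i.e. heads are contiguous per token.
  int b = 1, h = 5, l = 2, d = 7;
  int64_t strided = b * str_oB + h * str_oH + l * str_oL + d * str_oD;
  int64_t blhd = ((int64_t(b) * L + l) * H + h) * D + d;
  std::printf("%lld %lld\n", (long long)strided, (long long)blhd); // equal
  return 0;
}
```
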
```diff
@@ -1075,4 +1107,4 @@
 
 } // namespace fast
 
 } // namespace mlx::core
```