Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-21 18:28:11 +08:00
Add sdpa with sinks (#2558)
* add sdpa with sinks
* fix 2 pass
* fix matrix sdpa
* fix perf regression
* add to cuda (#2580)
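An attention sink is an extra per-head logit (typically learned) that joins the softmax normalization but carries no value vector, so a query can park probability mass on it instead of on real tokens. The scalar reference below is a minimal sketch of that math, mirroring what the kernels in this diff compute; the names (attend_with_sink, scores, sink) are local to the sketch and are not MLX API.

    // Softmax-weighted sum of values where a "sink" logit joins the
    // normalization but contributes no value.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    float attend_with_sink(
        const std::vector<float>& scores, // scaled q·k_i for each key i
        const std::vector<float>& values, // 1-D values for simplicity
        float sink) {                     // per-head sink logit
      float max_score = sink;
      for (float s : scores) {
        max_score = std::max(max_score, s);
      }
      // exp(sink - max) is the sink's share of the denominator.
      float denom = std::exp(sink - max_score);
      float out = 0.0f;
      for (size_t i = 0; i < scores.size(); ++i) {
        float w = std::exp(scores[i] - max_score);
        denom += w;
        out += w * values[i];
      }
      return out / denom;
    }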
@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
 constant bool bool_mask [[function_constant(23)]];
 constant bool float_mask [[function_constant(24)]];
+constant bool has_sinks [[function_constant(25)]];

 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
         [[buffer(14), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(15), function_constant(has_mask)]],
+    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(17), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
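The two new kernel arguments are gated by the has_sinks function constant, so existing launches without sinks never have to bind them. A rough buffer map for the single-pass kernel after this change, assuming the indices shown in the hunk above (the enum is illustrative, not MLX code):

    // Optional sink-related arguments of sdpa_vector (indices from the diff).
    enum SdpaVectorSinkBuffers : int {
      kSinks = 16,     // const device T*: one logit per query head
      kNumQHeads = 17, // const constant int&: number of query heads
    };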
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
   threadgroup U sum_exp_scores[BN];

   // Adjust positions
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int kv_head_idx = head_idx / gqa_factor;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    bmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }
   if (float_mask) {
-    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    fmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }

   out += o_offset * V + simd_gid * v_per_thread;
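The rename from head_idx to q_batch_head_idx reflects what tid.x actually indexes: the host now launches the grid with x = batch * num_q_heads (see the grid_dims change further down), so the value is a flattened (batch, head) index rather than a head index. A small sketch of how that flat index decomposes, which is why the sink lookup later uses a modulo by num_q_heads; the variable names are illustrative:

    #include <cstdio>

    int main() {
      int batch = 2, num_q_heads = 3;
      for (int q_batch_head_idx = 0; q_batch_head_idx < batch * num_q_heads;
           ++q_batch_head_idx) {
        int b = q_batch_head_idx / num_q_heads; // batch element
        int h = q_batch_head_idx % num_q_heads; // query head, used as sinks[h]
        std::printf("flat %d -> batch %d, head %d\n", q_batch_head_idx, b, h);
      }
      return 0;
    }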
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>

   U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && simd_gid == 0) {
+    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    sum_exp_score = 1;
+  }

   // For each key
   for (int i = simd_gid; i < N; i += BN) {
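Seeding max_score with the sink logit and sum_exp_score with 1 is equivalent to having already folded one extra logit, equal to the sink, into the online softmax (exp(sink - sink) == 1); since that phantom key has no value vector, it only enlarges the denominator. Only simd_gid == 0 seeds it, so the sink is counted exactly once when the per-simdgroup partials are combined. A minimal sketch of that seeding, not MLX code:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Online softmax denominator seeded with a sink logit.
    float denom_with_sink(const std::vector<float>& scores, float sink) {
      float max_score = sink; // as if the sink were the first logit seen
      float sum_exp = 1.0f;   // exp(sink - sink)
      for (float s : scores) {
        float new_max = std::max(max_score, s);
        sum_exp = sum_exp * std::exp(max_score - new_max) + std::exp(s - new_max);
        max_score = new_max;
      }
      // Equals exp(sink) + sum_i exp(scores[i]).
      return sum_exp * std::exp(max_score);
    }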
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
     }
     if (use_key) {
       // Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
       }
       score = simd_sum(score);
       if (float_mask) {
-        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+        score += static_cast<U>(fmask[0]);
       }

       // Update the accumulators
       U new_max = max(max_score, score);
-      U factor = fast::exp(max_score - new_max);
-      U exp_score = fast::exp(score - new_max);
+      bool is_neg_inf = new_max == -INFINITY;
+      U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
+      U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);

       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
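With max_score now starting at -INFINITY (and the float mask no longer clamped), a fully masked row can reach this update with both max_score and score equal to -INFINITY, and exp(-inf - -inf) is exp(NaN); the is_neg_inf guard pins factor to 1 and exp_score to 0 in that case. A small demonstration of the failure mode, illustrative only:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      const float kNegInf = -std::numeric_limits<float>::infinity();
      float max_score = kNegInf; // nothing accumulated yet
      float score = kNegInf;     // fully masked key
      float new_max = std::max(max_score, score);
      std::printf("unguarded: %f\n", std::exp(max_score - new_max)); // nan
      bool is_neg_inf = new_max == kNegInf;
      float factor = is_neg_inf ? 1.0f : std::exp(max_score - new_max);
      float exp_score = is_neg_inf ? 0.0f : std::exp(score - new_max);
      std::printf("guarded: factor=%f exp_score=%f\n", factor, exp_score);
      return 0;
    }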
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
         [[buffer(16), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(17), function_constant(has_mask)]],
+    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(19), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>

   // Adjust positions
   const int block_idx = tid.z;
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
-  const int kv_head_idx = head_idx / gqa_factor;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;

   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride +
+    bmask += q_batch_head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
   if (float_mask) {
-    fmask += head_idx * mask_head_stride +
+    fmask += q_batch_head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
     o[i] = 0;
   }

-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && block_idx == 0 && simd_gid == 0) {
+    int q_head_idx = q_batch_head_idx % num_q_heads;
+    max_score = static_cast<U>(sinks[q_head_idx]);
+    sum_exp_score = 1;
+  }

   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
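In the two-pass kernel each block accumulates its own (max_score, sum_exp_score, partial output) over a slice of the keys, and a second pass merges the partials. The sink is therefore seeded only in block_idx == 0 (and simd_gid == 0); seeding it in every block would add exp(sink) to the denominator once per block. A sketch of the merge step under that assumption, not the actual second-pass kernel:

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    struct Partial {
      float max_score;
      float sum_exp; // sum of exp(score - max_score) within the block
    };

    // Combine per-block partials; a sink folded into one block is counted once.
    Partial merge(const std::vector<Partial>& parts) {
      float m = -std::numeric_limits<float>::infinity();
      for (const auto& p : parts) m = std::max(m, p.max_score);
      float sum = 0.0f;
      for (const auto& p : parts) sum += p.sum_exp * std::exp(p.max_score - m);
      return {m, sum};
    }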
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
     }
     if (use_key) {
       // Read the key
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (score < Limits<T>::finite_min) {
+        continue;
+      }
+
       if (float_mask) {
         score += fmask[0];
       }
@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];

 constant bool has_mask [[function_constant(300)]];
 constant bool do_causal [[function_constant(301)]];
+constant bool has_sinks [[function_constant(302)]];

 template <typename T>
 struct TransformScale {
@@ -82,6 +83,7 @@ template <
     const constant AttnParams* params [[buffer(4)]],
     const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
     const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
+    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
   VBlockLoader loader_v(
       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

-  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
+  TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));

   // Prepare MMA tiles
   constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
     max_score[i] = Limits<AccumType>::finite_min;
   }

+  if (has_sinks) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
+      sum_score[i] = 1;
+    }
+  }
+
   int kb_lim = params->NK;

   if (do_causal) {
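The steel (matrix) kernel keeps its running softmax in base-2 exponent space: logits are pre-multiplied by log2(e) (the M_LOG2E_F constant that replaces the 1.44269504089 literal above) so exp2 can be used instead of exp. The sink must be scaled the same way before it seeds max_score, and sum_score starts at 1 just as in the vector kernels; sinks[tidl.y] picks the current head, since the second grid dimension enumerates heads (grid_dims = (NQ, H, B)). The identity being relied on, as a tiny check:

    #include <cmath>
    #include <cstdio>

    int main() {
      const float kLog2e = 1.4426950408889634f; // log2(e)
      float x = 0.73f;
      std::printf("exp  = %.7f\n", std::exp(x));
      std::printf("exp2 = %.7f\n", std::exp2(x * kLog2e)); // same value
      return 0;
    }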
@@ -350,7 +360,7 @@ template <
             Stile.frag_at(i, j)[jj] =
                 mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
           } else {
-            Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+            Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
           }
         }
       }
@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false,
-    const std::optional<array>& mask = std::nullopt) {
+    bool do_causal_,
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   using namespace mlx::steel;

   int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(
   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
-  const bool has_mask = !!mask;
+  const bool has_mask = mask.has_value();
   const bool do_causal = do_causal_;
+  const bool has_sinks = sinks.has_value();

   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
       {&has_mask, MTL::DataType::DataTypeBool, 300},
-      {&do_causal, MTL::DataType::DataTypeBool, 301}};
+      {&do_causal, MTL::DataType::DataTypeBool, 301},
+      {&has_sinks, MTL::DataType::DataTypeBool, 302}};

-  std::ostringstream kname;
-  // clang-format off
-  kname << "steel_attention_"
-        << type_to_name(q)
-        << "_bq" << bq
-        << "_bk" << bk
-        << "_bd" << bd
-        << "_wm" << wm
-        << "_wn" << wn
-        << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
-
-  std::string base_name = kname.str();
-
-  // clang-format off
-  kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n')
-        << "_has_mask_" << (has_mask ? 't' : 'n')
-        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string base_name;
+  concatenate(
+      base_name,
+      "steel_attention_",
+      type_to_name(q),
+      "_bq",
+      bq,
+      "_bk",
+      bk,
+      "_bd",
+      bd,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_mask",
+      type_to_name(has_mask ? *mask : q));
+
+  std::string hash_name;
+  concatenate(
+      hash_name,
+      base_name,
+      "_align_Q_",
+      (align_Q ? 't' : 'n'),
+      "_align_K_",
+      (align_K ? 't' : 'n'),
+      "_has_mask_",
+      (has_mask ? 't' : 'n'),
+      "_do_causal_",
+      (do_causal ? 't' : 'n'),
+      "_has_sinks_",
+      (has_sinks ? 't' : 'n'));

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
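base_name picks the kernel template to instantiate, while hash_name keys the cache of specialized pipelines, so it has to encode every function constant that changes codegen; that is why "_has_sinks_" is appended alongside the new 302 entry in func_consts. A toy cache illustrating the hazard if the key omitted it (not how MLX's kernel cache is implemented):

    #include <string>
    #include <unordered_map>

    struct Pipeline {
      bool has_sinks; // stands in for a compiled specialization
    };

    std::unordered_map<std::string, Pipeline> cache;

    Pipeline& get_pipeline(const std::string& base, bool has_sinks) {
      // Without the suffix, a no-sinks pipeline could be reused for a sinks
      // launch that happens to share the same base name.
      std::string key = base + (has_sinks ? "_has_sinks_t" : "_has_sinks_n");
      auto it = cache.find(key);
      if (it == cache.end()) {
        it = cache.emplace(key, Pipeline{has_sinks}).first;
      }
      return it->second;
    }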
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);

-  if (mask) {
-    auto m = *mask;
+  if (has_mask) {
+    auto& m = *mask;

     AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
         m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
     compute_encoder.set_bytes(mask_params, 5);
     compute_encoder.set_input_array(m, 6);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 7);
+  }

   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
     array& out,
     float scale,
     bool do_causal,
-    const std::optional<array>& mask) {
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
   int N = k.shape(2);
-  int B = q.shape(0) * q.shape(1);
   size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
   size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];

   MTL::Size group_dims(1024, 1, 1);
-  MTL::Size grid_dims(B, q.shape(2), 1);
+  MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);

   bool has_mask = mask.has_value();
   bool bool_mask = has_mask && (*mask).dtype() == bool_;
   bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
+  bool has_sinks = sinks.has_value();
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
       {&bool_mask, MTL::DataType::DataTypeBool, 23},
       {&float_mask, MTL::DataType::DataTypeBool, 24},
+      {&has_sinks, MTL::DataType::DataTypeBool, 25},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
+  hash_name += has_sinks ? "_sinks" : "_nosinks";

   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
     compute_encoder.set_bytes(q_seq_stride, 14);
     compute_encoder.set_bytes(head_stride, 15);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 16);
+    compute_encoder.set_bytes(q.shape(1), 17);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
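The host binds the sinks array at buffer 16 and q.shape(1), the number of query heads, at 17, matching the kernel's sinks[q_batch_head_idx % num_q_heads] lookup; the sinks array is expected to hold one logit per query head. A small sketch of that contract (names are illustrative, not MLX API):

    #include <cstddef>
    #include <vector>

    struct SinkArgs {
      const float* sinks; // buffer 16: one entry per query head
      int num_q_heads;    // buffer 17: q.shape(1)
    };

    // Shape check a host might perform before binding the optional arguments.
    bool sinks_shape_ok(const std::vector<float>& sinks, int num_q_heads) {
      return sinks.size() == static_cast<size_t>(num_q_heads);
    }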
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
     array& out,
     float scale,
     bool do_causal,
-    const std::optional<array>& mask) {
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
   bool bool_mask = has_mask && (*mask).dtype() == bool_;
   bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
+  bool has_sinks = sinks.has_value();
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
       {&bool_mask, MTL::DataType::DataTypeBool, 23},
       {&float_mask, MTL::DataType::DataTypeBool, 24},
+      {&has_sinks, MTL::DataType::DataTypeBool, 25},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
+  hash_name += has_sinks ? "_sinks" : "_nosinks";

   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
     compute_encoder.set_bytes(q_seq_stride, 16);
     compute_encoder.set_bytes(head_stride, 17);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 18);
+    compute_encoder.set_bytes(q.shape(1), 19);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
     return arr.strides(-1) == 1;
   };

+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+  bool has_arr_mask = inputs.size() > (3 + has_sinks_);
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) <= 8) {
     auto q_copy_unless = [](const array& arr) {
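The primitive's inputs are laid out as [q, k, v, (mask), (sinks)]: sinks, when present, is always the last input (hence inputs.back()), and an array mask sits at index 3 only if inputs.size() exceeds 3 plus the sinks flag. A standalone sketch of that unpacking logic, using illustrative names:

    #include <cstddef>

    struct Unpacked {
      bool has_mask;
      bool has_sinks;
      int mask_index;  // valid only if has_mask
      int sinks_index; // valid only if has_sinks
    };

    Unpacked unpack(size_t num_inputs, bool has_sinks_flag) {
      Unpacked u{};
      u.has_sinks = has_sinks_flag;
      u.sinks_index = has_sinks_flag ? static_cast<int>(num_inputs) - 1 : -1;
      u.has_mask = num_inputs > static_cast<size_t>(3 + (has_sinks_flag ? 1 : 0));
      u.mask_index = u.has_mask ? 3 : -1;
      return u;
    }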
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
         (strides[0] == strides[1] * shape[1]);
     };

-    auto mask = inputs.size() > 3
+    auto mask = has_arr_mask
         ? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
         : std::nullopt;

@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
     char devc = d.get_architecture().back();
     if ((devc == 'd' && k.shape(2) >= 1024) ||
         (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
     } else {
-      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
+      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
     }
   }

@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
       {str_oB, str_oH, str_oL, str_oD},
       flags);

-  auto mask = inputs.size() > 3
+  auto mask = has_arr_mask
       ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
       : std::nullopt;

-  sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
+  sdpa_full_self_attention_metal(
+      s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
 }

 d.add_temporaries(std::move(copies), s.index);