Allow different value dimensions in sdpa_vector (#1811)

This commit is contained in:
Angelos Katharopoulos
2025-01-31 20:58:59 -08:00
committed by GitHub
parent b7c9f1d38f
commit f5cc1eea72
6 changed files with 127 additions and 72 deletions

View File

@@ -7,15 +7,34 @@ using namespace metal;
// clang-format off
// SDPA vector instantiations
// Single-dimension form of the SDPA vector instantiation macro.
// For one element `type` and one `head_dim` it forwards three entries to
// instantiate_kernel: sdpa_vector, sdpa_vector_2pass_1 and
// sdpa_vector_2pass_2. Each host-visible name is the kernel name followed by
// "_<type>_<head_dim>" (built with preprocessor stringification), and the
// same `head_dim` is passed as the only dimension template argument.
// NOTE(review): a macro with the same name and a (type, qk_dim, value_dim)
// signature appears later in this view; this looks like the pre-change
// version from the diff - confirm before relying on it.
#define instantiate_sdpa_vector(type, head_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
// Instantiates the second-pass kernel sdpa_vector_2pass_2 for one element
// `type` and one `value_dim`.
// The host-visible kernel name is "sdpa_vector_2pass_2_<type>_<value_dim>";
// only the value dimension is part of the name and of the template arguments.
// NOTE(review): presumably the second pass only touches value-sized outputs
// and therefore does not need the query/key dimension - the kernel body is
// not visible here, confirm against it.
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \
"sdpa_vector_2pass_2_" #type "_" #value_dim, \
sdpa_vector_2pass_2, \
type, \
value_dim)
// For one element `type`, invokes the two-argument
// instantiate_sdpa_vector(type, head_dim) with head dimensions 64, 96 and 128.
// NOTE(review): a macro with the same name is defined again later in this
// view using three-argument calls; this looks like the pre-change version
// from the diff - confirm before relying on it.
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \
instantiate_sdpa_vector(type, 96) \
instantiate_sdpa_vector(type, 128)
// Instantiates the single-pass kernel (sdpa_vector) and the first pass of the
// two-pass variant (sdpa_vector_2pass_1) for one element `type`, one
// query/key dimension `qk_dim` and one value dimension `value_dim`.
// Host-visible names are
//   "sdpa_vector_<type>_<qk_dim>_<value_dim>" and
//   "sdpa_vector_2pass_1_<type>_<qk_dim>_<value_dim>",
// so both dimensions must appear, in this order, in any name used to look
// these kernels up.
// sdpa_vector_2pass_2 is not instantiated by this macro.
#define instantiate_sdpa_vector(type, qk_dim, value_dim) \
instantiate_kernel( \
"sdpa_vector_" #type "_" #qk_dim "_" #value_dim, \
sdpa_vector, \
type, \
qk_dim, \
value_dim) \
instantiate_kernel( \
"sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \
sdpa_vector_2pass_1, \
type, \
qk_dim, \
value_dim)
// For one element `type`, instantiates every SDPA vector kernel variant
// listed below:
//  - instantiate_sdpa_vector with equal query/key and value dimensions
//    (64/64, 96/96, 128/128); no mixed-dimension pair is listed here;
//  - instantiate_sdpa_vector_aggregation for value dimensions 64, 96, 128.
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128)
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)

View File

@@ -6,7 +6,7 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]];
template <typename T, int D>
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
@@ -25,14 +25,16 @@ template <typename T, int D>
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
constexpr int qk_per_thread = D / BD;
constexpr int v_per_thread = V / BD;
constexpr int inner_k_stride = BN * D;
constexpr int inner_v_stride = BN * V;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U o[v_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
@@ -41,19 +43,19 @@ template <typename T, int D>
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
queries += head_idx * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
}
out += head_idx * D + simd_gid * elem_per_thread;
out += head_idx * V + simd_gid * v_per_thread;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < qk_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0;
}
@@ -64,13 +66,13 @@ template <typename T, int D>
for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) {
// Read the key
for (int j = 0; j < elem_per_thread; j++) {
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
}
// Compute the i-th score
U score = 0;
for (int j = 0; j < elem_per_thread; j++) {
for (int j = 0; j < qk_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
@@ -84,14 +86,14 @@ template <typename T, int D>
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int j = 0; j < elem_per_thread; j++) {
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor + exp_score * values[j];
}
}
// Move the pointers to the next kv
keys += stride;
values += stride;
keys += inner_k_stride;
values += inner_v_stride;
if (has_mask) {
mask += BN * mask_seq_stride;
}
@@ -111,7 +113,7 @@ template <typename T, int D>
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
@@ -120,13 +122,13 @@ template <typename T, int D>
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}
template <typename T, int D>
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
@@ -147,15 +149,17 @@ template <typename T, int D>
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
constexpr int qk_per_thread = D / BD;
constexpr int v_per_thread = V / BD;
constexpr int inner_k_stride = BN * D;
constexpr int inner_v_stride = BN * V;
constexpr int blocks = 32;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U o[v_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
@@ -165,12 +169,12 @@ template <typename T, int D>
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
queries += head_idx * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
simd_lid * v_per_thread;
out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
@@ -179,10 +183,10 @@ template <typename T, int D>
maxs += head_idx * blocks + block_idx;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < qk_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0;
}
@@ -193,13 +197,13 @@ template <typename T, int D>
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < qk_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
@@ -213,14 +217,14 @@ template <typename T, int D>
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
if (has_mask) {
mask += BN * blocks * mask_seq_stride;
}
@@ -247,7 +251,7 @@ template <typename T, int D>
}
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@@ -124,6 +124,8 @@ void sdpa_vector(
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
@@ -185,6 +187,8 @@ void sdpa_vector_2pass(
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
@@ -256,7 +260,7 @@ void sdpa_vector_2pass(
kname += "sdpa_vector_2pass_2_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += std::to_string(v.shape(-1));
// Get the kernel
kernel = d.get_kernel(kname);
@@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
if (q.is_donatable() && q.size() == o.size()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));

View File

@@ -684,23 +684,20 @@ array scaled_dot_product_attention(
const size_t query_head_dim = q.shape(-1);
const size_t query_sequence_length = q.shape(2);
bool implementation_supports_use_case = query_head_dim == value_head_dim;
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 80;
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80);
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
stream.device == Device::gpu;
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu;
const bool supported_mask = !mask || (mask->dtype() == bool_);
const bool supports_sdpa_vector = query_sequence_length == 1 &&
supported_mask && sdpa_vector_supported_head_dim &&
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;
implementation_supports_use_case &=
const bool implementation_supports_use_case =
supports_sdpa_full || supports_sdpa_vector;
std::vector<array> inputs = {q, k, v};