Allow different value dimensions in sdpa_vector (#1811)

This commit is contained in:
Angelos Katharopoulos 2025-01-31 20:58:59 -08:00 committed by GitHub
parent b7c9f1d38f
commit f5cc1eea72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 127 additions and 72 deletions

View File

@ -8,14 +8,23 @@ L = 16384
H = 32 H = 32
H_k = H // 4 H_k = H // 4
D = 128 D = 128
V = 128
dtype = mx.float16 dtype = mx.float16
loops = 10 loops = 10
def attention(q, k, v, mask=None): def upproject(x, w):
if w is None:
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def _sdpa(q, k, v): def _sdpa(q, k, v):
B, Hq, L, D = q.shape B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape _, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D) q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :] k = k[:, :, None, :, :]
v = v[:, :, None, :, :] v = v[:, :, None, :, :]
@ -25,16 +34,18 @@ def attention(q, k, v, mask=None):
s = mx.where(m, s, mx.finfo(s.dtype).min) s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v o = p @ v
return o.reshape(B, Hq, L, D) return o.reshape(B, Hq, L, V)
for i in range(loops): for i in range(loops):
q = _sdpa(q, k, v) q = _sdpa(q, k, v)
q = upproject(q, w)
return q return q
def sdpa(q, k, v, mask=None): def sdpa(q, k, v, mask=None, w=None):
for i in range(loops): for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q = upproject(q, w)
return q return q
@ -42,34 +53,37 @@ def time_self_attention_primitives():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
mx.eval(q, k, v) w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
time_fn(attention, q, k, v) mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
def time_self_attention_sdpa(): def time_self_attention_sdpa():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
mx.eval(q, k, v) w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
time_fn(sdpa, q, k, v) mx.eval(q, k, v, w)
time_fn(sdpa, q, k, v, w=w)
def time_self_attention_sdpa_with_mask(): def time_self_attention_sdpa_with_mask():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mask = mx.full((L,), True) mask = mx.full((L,), True)
mask[L // 2 :] = False mask[L // 2 :] = False
mx.eval(q, k, v, mask) mx.eval(q, k, v, mask, w)
def sdpa_mask(*args): def sdpa_mask(*args):
return sdpa(*args, mask=mask) return sdpa(*args, mask=mask, w=w)
def attention_mask(*args): def attention_mask(*args):
return attention(*args, mask=mask) return attention(*args, mask=mask, w=w)
time_fn(attention_mask, q, k, v) time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v) time_fn(sdpa_mask, q, k, v)

View File

@ -7,15 +7,34 @@ using namespace metal;
// clang-format off // clang-format off
// SDPA vector instantiations // SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \ #define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \ instantiate_kernel( \
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \ "sdpa_vector_2pass_2_" #type "_" #value_dim, \
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim) sdpa_vector_2pass_2, \
type, \
value_dim)
#define instantiate_sdpa_vector(type, qk_dim, value_dim) \
instantiate_kernel( \
"sdpa_vector_" #type "_" #qk_dim "_" #value_dim, \
sdpa_vector, \
type, \
qk_dim, \
value_dim) \
instantiate_kernel( \
"sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \
sdpa_vector_2pass_1, \
type, \
qk_dim, \
value_dim)
#define instantiate_sdpa_vector_heads(type) \ #define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \ instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96) \ instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128) instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128)
instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(bfloat16_t)

View File

@ -6,7 +6,7 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]]; constant bool has_mask [[function_constant(20)]];
template <typename T, int D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]], const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]], const device T* keys [[buffer(1)]],
@ -25,14 +25,16 @@ template <typename T, int D>
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BD = 32; constexpr int BD = 32;
constexpr int elem_per_thread = D / BD; constexpr int qk_per_thread = D / BD;
constexpr int stride = BN * D; constexpr int v_per_thread = V / BD;
constexpr int inner_k_stride = BN * D;
constexpr int inner_v_stride = BN * V;
typedef float U; typedef float U;
thread U q[elem_per_thread]; thread U q[qk_per_thread];
thread U k[elem_per_thread]; thread U k[qk_per_thread];
thread U o[elem_per_thread]; thread U o[v_per_thread];
threadgroup U outputs[BN * BD]; threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN]; threadgroup U max_scores[BN];
@ -41,19 +43,19 @@ template <typename T, int D>
// Adjust positions // Adjust positions
const int head_idx = tid.y; const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread; queries += head_idx * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
if (has_mask) { if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
} }
out += head_idx * D + simd_gid * elem_per_thread; out += head_idx * V + simd_gid * v_per_thread;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i]; q[i] = static_cast<U>(scale) * queries[i];
} }
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
o[i] = 0; o[i] = 0;
} }
@ -64,13 +66,13 @@ template <typename T, int D>
for (int i = simd_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) { if (!has_mask || mask[0]) {
// Read the key // Read the key
for (int j = 0; j < elem_per_thread; j++) { for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j]; k[j] = keys[j];
} }
// Compute the i-th score // Compute the i-th score
U score = 0; U score = 0;
for (int j = 0; j < elem_per_thread; j++) { for (int j = 0; j < qk_per_thread; j++) {
score += q[j] * k[j]; score += q[j] * k[j];
} }
score = simd_sum(score); score = simd_sum(score);
@ -84,14 +86,14 @@ template <typename T, int D>
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator // Update the output accumulator
for (int j = 0; j < elem_per_thread; j++) { for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor + exp_score * values[j]; o[j] = o[j] * factor + exp_score * values[j];
} }
} }
// Move the pointers to the next kv // Move the pointers to the next kv
keys += stride; keys += inner_k_stride;
values += stride; values += inner_v_stride;
if (has_mask) { if (has_mask) {
mask += BN * mask_seq_stride; mask += BN * mask_seq_stride;
} }
@ -111,7 +113,7 @@ template <typename T, int D>
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs // Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i]; outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
@ -120,13 +122,13 @@ template <typename T, int D>
// And write the output // And write the output
if (simd_lid == 0) { if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
out[i] = static_cast<T>(o[i]); out[i] = static_cast<T>(o[i]);
} }
} }
} }
template <typename T, int D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1( [[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]], const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]], const device T* keys [[buffer(1)]],
@ -147,15 +149,17 @@ template <typename T, int D>
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8; constexpr int BN = 8;
constexpr int BD = 32; constexpr int BD = 32;
constexpr int elem_per_thread = D / BD; constexpr int qk_per_thread = D / BD;
constexpr int stride = BN * D; constexpr int v_per_thread = V / BD;
constexpr int inner_k_stride = BN * D;
constexpr int inner_v_stride = BN * V;
constexpr int blocks = 32; constexpr int blocks = 32;
typedef float U; typedef float U;
thread U q[elem_per_thread]; thread U q[qk_per_thread];
thread U k[elem_per_thread]; thread U k[qk_per_thread];
thread U o[elem_per_thread]; thread U o[v_per_thread];
threadgroup U outputs[BN * BD]; threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN]; threadgroup U max_scores[BN];
@ -165,12 +169,12 @@ template <typename T, int D>
const int block_idx = tid.z; const int block_idx = tid.z;
const int head_idx = tid.y; const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread; queries += head_idx * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D + keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread; simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
simd_lid * elem_per_thread; simd_lid * v_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) { if (has_mask) {
mask += head_idx * mask_head_stride + mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride; (block_idx * BN + simd_gid) * mask_seq_stride;
@ -179,10 +183,10 @@ template <typename T, int D>
maxs += head_idx * blocks + block_idx; maxs += head_idx * blocks + block_idx;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i]; q[i] = static_cast<U>(scale) * queries[i];
} }
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
o[i] = 0; o[i] = 0;
} }
@ -193,13 +197,13 @@ template <typename T, int D>
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) { if (!has_mask || mask[0]) {
// Read the key // Read the key
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i]; k[i] = keys[i];
} }
// Compute the i-th score // Compute the i-th score
U score = 0; U score = 0;
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = simd_sum(score); score = simd_sum(score);
@ -213,14 +217,14 @@ template <typename T, int D>
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator // Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i]; o[i] = o[i] * factor + exp_score * values[i];
} }
} }
// Move the pointers to the next kv // Move the pointers to the next kv
keys += blocks * stride; keys += blocks * inner_k_stride;
values += blocks * stride; values += blocks * inner_v_stride;
if (has_mask) { if (has_mask) {
mask += BN * blocks * mask_seq_stride; mask += BN * blocks * mask_seq_stride;
} }
@ -247,7 +251,7 @@ template <typename T, int D>
} }
// Now we need to aggregate all the outputs // Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] = outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max); o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@ -124,6 +124,8 @@ void sdpa_vector(
kname += get_type_string(q.dtype()); kname += get_type_string(q.dtype());
kname += "_"; kname += "_";
kname += std::to_string(q.shape(-1)); kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes // Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1); int gqa_factor = q.shape(1) / k.shape(1);
@ -185,6 +187,8 @@ void sdpa_vector_2pass(
kname += get_type_string(q.dtype()); kname += get_type_string(q.dtype());
kname += "_"; kname += "_";
kname += std::to_string(q.shape(-1)); kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes // Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1); int gqa_factor = q.shape(1) / k.shape(1);
@ -256,7 +260,7 @@ void sdpa_vector_2pass(
kname += "sdpa_vector_2pass_2_"; kname += "sdpa_vector_2pass_2_";
kname += get_type_string(q.dtype()); kname += get_type_string(q.dtype());
kname += "_"; kname += "_";
kname += std::to_string(q.shape(-1)); kname += std::to_string(v.shape(-1));
// Get the kernel // Get the kernel
kernel = d.get_kernel(kname); kernel = d.get_kernel(kname);
@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible // Donate the query if possible
if (q.is_donatable()) { if (q.is_donatable() && q.size() == o.size()) {
o.move_shared_buffer(q); o.move_shared_buffer(q);
} else { } else {
o.set_data(allocator::malloc_or_wait(o.nbytes())); o.set_data(allocator::malloc_or_wait(o.nbytes()));

View File

@ -684,23 +684,20 @@ array scaled_dot_product_attention(
const size_t query_head_dim = q.shape(-1); const size_t query_head_dim = q.shape(-1);
const size_t query_sequence_length = q.shape(2); const size_t query_sequence_length = q.shape(2);
bool implementation_supports_use_case = query_head_dim == value_head_dim;
const bool sdpa_vector_supported_head_dim = const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128; query_head_dim == value_head_dim &&
const bool sdpa_full_supported_head_dim = (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
query_head_dim == 64 || query_head_dim == 80; const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80);
const bool supports_sdpa_full = query_sequence_length >= threshold && const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
!mask.has_value() && sdpa_full_supported_head_dim && sdpa_full_supported_head_dim && stream.device == Device::gpu;
stream.device == Device::gpu;
const bool supported_mask = !mask || (mask->dtype() == bool_);
const bool supports_sdpa_vector = query_sequence_length == 1 && const bool supports_sdpa_vector = query_sequence_length == 1 &&
supported_mask && sdpa_vector_supported_head_dim && (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu; stream.device == Device::gpu;
implementation_supports_use_case &= const bool implementation_supports_use_case =
supports_sdpa_full || supports_sdpa_vector; supports_sdpa_full || supports_sdpa_vector;
std::vector<array> inputs = {q, k, v}; std::vector<array> inputs = {q, k, v};

View File

@ -262,6 +262,23 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
@unittest.skip("Different head and value dims is not enabled")
def test_fast_sdpa_vector_value_dims(self):
D = 192
V = 128
Nq = 4
Nkv = 1
scale = 1.0
mx.random.seed(0)
for L in [43, 128, 237, 8192]:
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, V))
ref = mlx_primitives_sdpa(q, k, v, scale)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(failfast=True) unittest.main(failfast=True)