SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa

* batch sdpa for query
This commit is contained in:
Awni Hannun 2025-03-04 10:59:04 -08:00 committed by GitHub
parent 6bcd6bcf70
commit e613d0eaf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 159 additions and 45 deletions

View File

@ -5,6 +5,7 @@
using namespace metal; using namespace metal;
constant bool has_mask [[function_constant(20)]]; constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
template <typename T, int D, int V = D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
@ -18,9 +19,11 @@ template <typename T, int D, int V = D>
const constant size_t& v_stride, const constant size_t& v_stride,
const constant float& scale, const constant float& scale,
const device bool* mask [[function_constant(has_mask)]], const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]], const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32; constexpr int BN = 32;
@ -41,15 +44,21 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN]; threadgroup U sum_exp_scores[BN];
// Adjust positions // Adjust positions
const int head_idx = tid.y; const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * qk_per_thread; const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread; values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
if (has_mask) { if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
} }
out += head_idx * V + simd_gid * v_per_thread;
out += o_offset * V + simd_gid * v_per_thread;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
for (int i = 0; i < qk_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
@ -95,7 +104,7 @@ template <typename T, int D, int V = D>
keys += inner_k_stride; keys += inner_k_stride;
values += inner_v_stride; values += inner_v_stride;
if (has_mask) { if (has_mask) {
mask += BN * mask_seq_stride; mask += BN * mask_kv_seq_stride;
} }
} }
@ -142,9 +151,11 @@ template <typename T, int D, int V = D>
const constant size_t& v_stride, const constant size_t& v_stride,
const constant float& scale, const constant float& scale,
const device bool* mask [[function_constant(has_mask)]], const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]], const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8; constexpr int BN = 8;
@ -167,20 +178,26 @@ template <typename T, int D, int V = D>
// Adjust positions // Adjust positions
const int block_idx = tid.z; const int block_idx = tid.z;
const int head_idx = tid.y; const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * qk_per_thread;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D + keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * qk_per_thread; simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V + values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
simd_lid * v_per_thread; simd_lid * v_per_thread;
out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) { if (has_mask) {
mask += head_idx * mask_head_stride + mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride; (block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
} }
sums += head_idx * blocks + block_idx; sums += o_offset * blocks + block_idx;
maxs += head_idx * blocks + block_idx; maxs += o_offset * blocks + block_idx;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
for (int i = 0; i < qk_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
@ -226,7 +243,7 @@ template <typename T, int D, int V = D>
keys += blocks * inner_k_stride; keys += blocks * inner_k_stride;
values += blocks * inner_v_stride; values += blocks * inner_v_stride;
if (has_mask) { if (has_mask) {
mask += BN * blocks * mask_seq_stride; mask += BN * blocks * mask_kv_seq_stride;
} }
} }
@ -275,6 +292,7 @@ template <typename T, int D>
const device float* maxs [[buffer(2)]], const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]], device T* out [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32; constexpr int BN = 32;
@ -288,11 +306,14 @@ template <typename T, int D>
threadgroup U outputs[BN * BD]; threadgroup U outputs[BN * BD];
// Adjust positions // Adjust positions
const int head_idx = tid.y; const int head_idx = tid.x;
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread; const int q_seq_idx = tid.y;
sums += head_idx * blocks; const int n_heads = tpg.x;
maxs += head_idx * blocks; const int q_offset = n_heads * q_seq_idx + head_idx;
out += head_idx * D + simd_gid * elem_per_thread; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;
out += q_offset * D + simd_gid * elem_per_thread;
// First everybody reads the max and sum_exp // First everybody reads the max and sum_exp
U max_score = maxs[simd_lid]; U max_score = maxs[simd_lid];

View File

@ -25,6 +25,10 @@ void RoPE::eval_gpu(
size_t out_strides[3]; size_t out_strides[3];
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1); size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) { if (dims_ < in.shape(-1)) {
donated = true; donated = true;
@ -44,12 +48,12 @@ void RoPE::eval_gpu(
strides[0] = mat_size; strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2]; strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1]; strides[2] = in.strides()[ndim - 1];
} else if (ndim == 3) { } else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs // Handle non-contiguous 3D inputs
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
strides[0] = in.strides()[0]; strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[1]; strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[2]; strides[2] = in.strides()[ndim - 1];
} else { } else {
// Copy non-contiguous > 3D inputs into the output and treat // Copy non-contiguous > 3D inputs into the output and treat
// input as donated // input as donated

View File

@ -134,14 +134,17 @@ void sdpa_vector(
size_t k_stride = k.strides()[1]; size_t k_stride = k.strides()[1];
size_t v_stride = v.strides()[1]; size_t v_stride = v.strides()[1];
MTL::Size group_dims(1024, 1, 1); MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1); MTL::Size grid_dims(B, q.shape(2), 1);
bool has_mask = mask.has_value(); bool has_mask = mask.has_value();
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20}, {&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
}; };
std::string hash_name = kname; std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask"; hash_name += has_mask ? "_mask" : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@ -161,10 +164,14 @@ void sdpa_vector(
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 9); compute_encoder.set_input_array(m, 9);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; auto nd = m.ndim();
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; int32_t kv_seq_stride =
compute_encoder.set_bytes(seq_stride, 10); nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
compute_encoder.set_bytes(head_stride, 11); int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 10);
compute_encoder.set_bytes(q_seq_stride, 11);
compute_encoder.set_bytes(head_stride, 12);
} }
// Launch // Launch
@ -198,7 +205,7 @@ void sdpa_vector_2pass(
auto k_stride = k.strides()[1]; auto k_stride = k.strides()[1];
auto v_stride = v.strides()[1]; auto v_stride = v.strides()[1];
MTL::Size group_dims(8 * 32, 1, 1); MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(1, B, blocks); MTL::Size grid_dims(B, q.shape(2), blocks);
// Allocate the intermediates // Allocate the intermediates
Shape intermediate_shape; Shape intermediate_shape;
@ -219,11 +226,14 @@ void sdpa_vector_2pass(
d.add_temporary(maxs, s.index); d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value(); bool has_mask = mask.has_value();
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20}, {&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
}; };
std::string hash_name = kname; std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask"; hash_name += has_mask ? "_mask" : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@ -246,10 +256,14 @@ void sdpa_vector_2pass(
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 11); compute_encoder.set_input_array(m, 11);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; auto nd = m.ndim();
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; int32_t kv_seq_stride =
compute_encoder.set_bytes(seq_stride, 12); nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
compute_encoder.set_bytes(head_stride, 13); int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 12);
compute_encoder.set_bytes(q_seq_stride, 13);
compute_encoder.set_bytes(head_stride, 14);
} }
// Launch // Launch
@ -274,7 +288,7 @@ void sdpa_vector_2pass(
// Launch // Launch
group_dims = MTL::Size(1024, 1, 1); group_dims = MTL::Size(1024, 1, 1);
grid_dims = MTL::Size(1, B, 1); grid_dims = MTL::Size(B, q.shape(2), 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
@ -301,16 +315,23 @@ void ScaledDotProductAttention::eval_gpu(
if (!predicate(arr)) { if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s); copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); copies.push_back(std::move(arr_copy));
return copies.back(); return copies.back();
} else { } else {
return arr; return arr;
} }
}; };
// Checks if arr is fully row contiguous // Checks if arr is row contiguous or the sequence and head dimension are
auto is_contiguous = [](const array& arr) { // transposed
return arr.flags().row_contiguous; auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
}; };
// Returns true if the array is row contiguous except the sequence length // Returns true if the array is row contiguous except the sequence length
@ -328,18 +349,30 @@ void ScaledDotProductAttention::eval_gpu(
}; };
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) == 1) { if (q_pre.shape(2) <= 8) {
const auto& q = copy_unless(is_contiguous, q_pre); const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
// 1, heads, seq_len, head_dim
// mask [1, query_heads, 1, seq_len]
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible // Donate the query if possible
if (q.is_donatable() && q.size() == o.size()) { if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
q.size() == o.size()) {
o.move_shared_buffer(q); o.move_shared_buffer(q);
} else { } else {
o.set_data(allocator::malloc_or_wait(o.nbytes())); if (o.shape(2) == 1) {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
strides[1] = o.shape(3);
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
o.size(),
std::move(strides),
flags);
}
} }
auto mask = auto mask =

View File

@ -715,7 +715,8 @@ array scaled_dot_product_attention(
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu; sdpa_full_supported_head_dim && stream.device == Device::gpu;
const bool supports_sdpa_vector = query_sequence_length == 1 && const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= k.shape(-2)) &&
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim && (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu; stream.device == Device::gpu;

View File

@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fast_sdpa_few_query(self):
D = 64
L = 43
Lq = 4
Nq = 8
Nkv = 1
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
q = q.swapaxes(1, 2)
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
return
L = 4096
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
@unittest.skip("Different head and value dims is not enabled") @unittest.skip("Different head and value dims is not enabled")
def test_fast_sdpa_vector_value_dims(self): def test_fast_sdpa_vector_value_dims(self):
D = 192 D = 192