fix batched vector sdpa (#2152)

Awni Hannun, 2025-05-05 13:13:03 -07:00 (committed by GitHub)
parent 825124af8f
commit af705590ac
3 changed files with 105 additions and 50 deletions

Changed file 1 of 3: Metal SDPA vector kernels

@@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int o_offset = head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
+      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
@@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
   const int block_idx = tid.z;
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int o_offset = head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
+      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
   const int kv_head_idx = head_idx / gqa_factor;
   queries += q_offset * D + simd_lid * qk_per_thread;
@@ -358,8 +358,8 @@ template <typename T, int D>
   // Adjust positions
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int n_heads = tpg.x;
-  const int q_offset = n_heads * q_seq_idx + head_idx;
+  const int q_offset = head_idx * tpg.y + q_seq_idx;
+  ;
   partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
   sums += q_offset * blocks;
   maxs += q_offset * blocks;
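A note on the kernel hunks above: the output offset now advances fastest along the query-sequence dimension (o_offset = head_idx * tpg.y + q_seq_idx) rather than along the flattened batch/head dimension, so each head writes its queries into a row-contiguous block of the output. A minimal Python sketch of the two formulas; the grid sizes are made-up stand-ins for tpg.x and tpg.y, not values from the launch code:

# Illustrative only: tpg_x stands in for the flattened batch*head count and
# tpg_y for the number of queries per head.
tpg_x, tpg_y = 8, 3

def old_o_offset(head_idx, q_seq_idx):
    # old layout: heads adjacent for a fixed query position
    return tpg_x * q_seq_idx + head_idx

def new_o_offset(head_idx, q_seq_idx):
    # new layout: query positions adjacent for a fixed head, matching a
    # row-contiguous (batch*head, q_seq, D) output
    return head_idx * tpg_y + q_seq_idx

visit = [new_o_offset(h, s) for h in range(tpg_x) for s in range(tpg_y)]
assert visit == list(range(tpg_x * tpg_y))  # offsets cover the output in order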

Changed file 2 of 3: Metal backend SDPA dispatch (C++)

@@ -154,9 +154,9 @@ void sdpa_vector(
   int gqa_factor = q.shape(1) / k.shape(1);
   int N = k.shape(2);
   int B = q.shape(0) * q.shape(1);
-  size_t k_head_stride = k.strides()[1];
+  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
-  size_t v_head_stride = v.strides()[1];
+  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];
   MTL::Size group_dims(1024, 1, 1);
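For the k_head_stride / v_head_stride change above: the kernel flattens batch and query heads into one grid dimension and derives the KV index as head_idx / gqa_factor, so when the KV head dimension is a singleton that quotient is effectively the batch index and the stride has to come from the batch axis instead. A small pure-Python sketch with hypothetical shapes, mirroring the k = normal(shape=(2, 3, 1, D)).swapaxes(1, 2) case from the tests:

# Hypothetical k of shape (batch=2, n_kv_heads=1, S=3, D=64) produced by a
# swapaxes, so it is not row contiguous: element strides are (S*D, D, D, 1).
batch, n_kv_heads, S, D = 2, 1, 3, 64
n_q_heads = 4
gqa_factor = n_q_heads // n_kv_heads            # 4
k_shape = (batch, n_kv_heads, S, D)
k_strides = (S * D, D, D, 1)

# mirror of: k.shape(1) == 1 ? k.strides(0) : k.strides(1)
k_head_stride = k_strides[0] if k_shape[1] == 1 else k_strides[1]

for head_idx in range(batch * n_q_heads):       # flattened batch*head grid dim
    kv_head_idx = head_idx // gqa_factor        # equals the batch index here
    base = kv_head_idx * k_head_stride          # start of that batch's K block
    assert base == kv_head_idx * S * D          # offset of element (b, 0, 0, 0)
# the old choice, k_strides[1] == D, would land inside batch 0 for head_idx >= 4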
@@ -199,11 +199,10 @@ void sdpa_vector(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 11 + float_mask);
-    auto nd = m.ndim();
-    int32_t kv_seq_stride =
-        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
-    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
-    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
+    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
+    int32_t head_stride =
+        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
     compute_encoder.set_bytes(kv_seq_stride, 13);
     compute_encoder.set_bytes(q_seq_stride, 14);
     compute_encoder.set_bytes(head_stride, 15);
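The mask handling above now assumes a 4-D mask and selects a stride of 0 for any size-1 axis. That is the standard broadcast-by-stride trick: advancing an index along a stride-0 axis re-reads the same elements, so a mask shared across batches, heads, or sequence positions is reused in place instead of being materialized. A small stride-arithmetic sketch with made-up shapes:

# Broadcast-by-stride-0 sketch for a hypothetical mask of shape (1, 1, 4, 7):
# only the last two axes carry real data; batch and head axes broadcast.
m_shape = (1, 1, 4, 7)
m_strides = (28, 28, 7, 1)          # row-contiguous element strides

kv_seq_stride = m_strides[3] if m_shape[3] > 1 else 0   # 1
q_seq_stride = m_strides[2] if m_shape[2] > 1 else 0    # 7
head_stride = (m_strides[1] if m_shape[1] > 1
               else (m_strides[0] if m_shape[0] > 1 else 0))  # 0

def offset(h, q, kv):
    # h is the flattened batch*head index used by the kernel
    return h * head_stride + q * q_seq_stride + kv * kv_seq_stride

# every head and batch reads the same 4x7 tile
assert offset(0, 2, 5) == offset(3, 2, 5) == 2 * 7 + 5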
@@ -238,9 +237,10 @@ void sdpa_vector_2pass(
   int N = k.shape(2);
   int blocks = 32;
   int B = q.shape(0) * q.shape(1);
-  size_t k_head_stride = k.strides()[1];
+  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
-  size_t v_head_stride = v.strides()[1];
+  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];
   MTL::Size group_dims(8 * 32, 1, 1);
   MTL::Size grid_dims(B, q.shape(2), blocks);
@@ -302,11 +302,10 @@ void sdpa_vector_2pass(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 13 + float_mask);
-    auto nd = m.ndim();
-    int32_t kv_seq_stride =
-        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
-    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
-    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
+    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
+    int32_t head_stride =
+        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
     compute_encoder.set_bytes(kv_seq_stride, 15);
     compute_encoder.set_bytes(q_seq_stride, 16);
     compute_encoder.set_bytes(head_stride, 17);
@@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
-  // Checks if arr is row contiguous or the sequence and head dimension are
-  // transposed
-  auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return true;
-    }
-    auto& strides = arr.strides();
-    auto& shape = arr.shape();
-    return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
-        (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
-  };
-
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
     return arr.strides(-1) == 1;
@@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
   // We are in vector mode ie single query
   if (q_pre.shape(2) <= 8) {
-    const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
-    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
-    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
+    auto q_copy_unless = [](const array& arr) {
+      if (arr.flags().row_contiguous) {
+        return true;
+      }
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      if (shape[0] == 1 || shape[1] == 1) {
+        // If either the batch or head dimension is a singleton, the other can
+        // be transposed with the sequence dimension
+        auto bidx = shape[0] == 1 ? 1 : 0;
+        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
+            (strides[bidx] == shape[3]);
+      }
+      return false;
+    };
+
+    auto kv_copy_unless = [](const array& arr) {
+      // keys and values should be copied if:
+      // - the last dimension is not contiguous
+      // - the batch and head dim are not contiguous
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      if (strides.back() != 1) {
+        return false;
+      }
+      if (shape[0] == 1 || shape[1] == 1) {
+        return true;
+      }
+      return (strides[0] == strides[1] * shape[1]);
+    };
+
+    const auto& q = copy_unless(q_copy_unless, q_pre);
+    const auto& k = copy_unless(kv_copy_unless, k_pre);
+    const auto& v = copy_unless(kv_copy_unless, v_pre);
 
     // Donate the query if possible
-    if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
-        q.size() == o.size()) {
+    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
     } else {
-      if (o.shape(2) == 1) {
-        o.set_data(allocator::malloc(o.nbytes()));
-      } else {
-        auto strides = o.strides();
-        strides[2] = o.shape(1) * o.shape(3);
-        strides[1] = o.shape(3);
-        auto flags = q.flags();
-        flags.row_contiguous = q.shape(1) == 1;
-        o.set_data(
-            allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
-      }
+      o.set_data(allocator::malloc(o.nbytes()));
     }
 
-    auto mask =
-        inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
+    auto mask_copy_unless = [&q](const array& arr) {
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
+          (strides[0] == strides[1] * shape[1]);
+    };
+    auto mask = inputs.size() > 3
+        ? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
+        : std::nullopt;
 
     // We route to the 2 pass fused attention if
     // - The device is large and the sequence length long
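The q_copy_unless / kv_copy_unless / mask_copy_unless predicates above decide when an input can be fed to the kernel through its strides alone and when it must first be copied into a contiguous layout. A pure-Python restatement of the key/value check, exercised on two hypothetical layouts; this is an illustration of the condition, not the MLX API:

# Restatement of kv_copy_unless: returns True when no copy is needed.
def kv_needs_no_copy(shape, strides):
    if strides[-1] != 1:                 # head dim (D) must be contiguous
        return False
    if shape[0] == 1 or shape[1] == 1:   # singleton batch or head dim is fine
        return True
    return strides[0] == strides[1] * shape[1]   # batch/head dims are fused

D = 64
# k = normal(shape=(2, 3, 1, D)).swapaxes(1, 2): shape (2, 1, 3, D),
# strides (3*D, D, D, 1), usable as-is
assert kv_needs_no_copy((2, 1, 3, D), (3 * D, D, D, 1))
# k = normal(shape=(2, 3, 2, D)).swapaxes(1, 2): shape (2, 2, 3, D),
# strides (6*D, D, 2*D, 1): batch and head dims are not fused, so it is copied
assert not kv_needs_no_copy((2, 2, 3, D), (6 * D, D, 2 * D, 1))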

Changed file 3 of 3: fast SDPA Python tests

@@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_vector_batched(self):
+        D = 64
+        q = mx.random.normal(shape=(2, 1, 3, D))
+        k = mx.random.normal(shape=(2, 1, 3, D))
+        v = mx.random.normal(shape=(2, 1, 3, D))
+
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(2, 2, 3, D))
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        k = mx.random.normal(shape=(2, 1, 3, D))
+        v = mx.random.normal(shape=(2, 1, 3, D))
+        mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
+        ref = mlx_ref_attn(q, k, v, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 class TestSDPA(mlx_tests.MLXTestCase):
     @property
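One more note on the grouped-query case exercised above (4 query heads against 2 KV heads): per the kernel code in this diff, query heads map to KV heads by integer division with gqa_factor, so within a batch consecutive pairs of query heads share a KV head. A tiny sketch of that mapping for the tested shapes:

# GQA head mapping sketch (within one batch; shapes taken from the test above).
n_q_heads, n_kv_heads = 4, 2
gqa_factor = n_q_heads // n_kv_heads                 # 2
mapping = {h: h // gqa_factor for h in range(n_q_heads)}
assert mapping == {0: 0, 1: 0, 2: 1, 3: 1}           # head pairs share a KV head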