mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix batched vector sdpa (#2152)
This commit is contained in:
parent
825124af8f
commit
af705590ac
@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||
simd_lid * qk_per_thread;
|
||||
@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
@ -358,8 +358,8 @@ template <typename T, int D>
|
||||
// Adjust positions
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int n_heads = tpg.x;
|
||||
const int q_offset = n_heads * q_seq_idx + head_idx;
|
||||
const int q_offset = head_idx * tpg.y + q_seq_idx;
|
||||
;
|
||||
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += q_offset * blocks;
|
||||
maxs += q_offset * blocks;
|
||||
|
@ -154,9 +154,9 @@ void sdpa_vector(
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
@ -199,11 +199,10 @@ void sdpa_vector(
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||
int32_t head_stride =
|
||||
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 13);
|
||||
compute_encoder.set_bytes(q_seq_stride, 14);
|
||||
compute_encoder.set_bytes(head_stride, 15);
|
||||
@ -238,9 +237,10 @@ void sdpa_vector_2pass(
|
||||
int N = k.shape(2);
|
||||
int blocks = 32;
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
|
||||
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
MTL::Size group_dims(8 * 32, 1, 1);
|
||||
MTL::Size grid_dims(B, q.shape(2), blocks);
|
||||
@ -302,11 +302,10 @@ void sdpa_vector_2pass(
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 13 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||
int32_t head_stride =
|
||||
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 15);
|
||||
compute_encoder.set_bytes(q_seq_stride, 16);
|
||||
compute_encoder.set_bytes(head_stride, 17);
|
||||
@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
// Checks if arr is row contiguous or the sequence and head dimension are
|
||||
// transposed
|
||||
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
}
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
|
||||
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
||||
};
|
||||
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(-1) == 1;
|
||||
@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) <= 8) {
|
||||
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
}
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
if (shape[0] == 1 || shape[1] == 1) {
|
||||
// If either the batch or head dimension is a singleton, the other can
|
||||
// be transposed with the sequence dimension
|
||||
auto bidx = shape[0] == 1 ? 1 : 0;
|
||||
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
|
||||
(strides[bidx] == shape[3]);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto kv_copy_unless = [](const array& arr) {
|
||||
// keys and values should be copied if:
|
||||
// - the last dimension is not contiguous
|
||||
// - the batch and head dim are not contiguous
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
if (strides.back() != 1) {
|
||||
return false;
|
||||
}
|
||||
if (shape[0] == 1 || shape[1] == 1) {
|
||||
return true;
|
||||
}
|
||||
return (strides[0] == strides[1] * shape[1]);
|
||||
};
|
||||
|
||||
const auto& q = copy_unless(q_copy_unless, q_pre);
|
||||
const auto& k = copy_unless(kv_copy_unless, k_pre);
|
||||
const auto& v = copy_unless(kv_copy_unless, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
||||
q.size() == o.size()) {
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
if (o.shape(2) == 1) {
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
} else {
|
||||
auto strides = o.strides();
|
||||
strides[2] = o.shape(1) * o.shape(3);
|
||||
strides[1] = o.shape(3);
|
||||
auto flags = q.flags();
|
||||
flags.row_contiguous = q.shape(1) == 1;
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
|
||||
}
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
}
|
||||
|
||||
auto mask =
|
||||
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
|
||||
auto mask_copy_unless = [&q](const array& arr) {
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
|
||||
(strides[0] == strides[1] * shape[1]);
|
||||
};
|
||||
|
||||
auto mask = inputs.size() > 3
|
||||
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
|
||||
: std::nullopt;
|
||||
|
||||
// We route to the 2 pass fused attention if
|
||||
// - The device is large and the sequence length long
|
||||
|
@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_vector_batched(self):
|
||||
D = 64
|
||||
q = mx.random.normal(shape=(2, 1, 3, D))
|
||||
k = mx.random.normal(shape=(2, 1, 3, D))
|
||||
v = mx.random.normal(shape=(2, 1, 3, D))
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||
k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
|
||||
v = mx.random.normal(shape=(2, 2, 3, D))
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||
k = mx.random.normal(shape=(2, 1, 3, D))
|
||||
v = mx.random.normal(shape=(2, 1, 3, D))
|
||||
mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||
ref = mlx_ref_attn(q, k, v, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
|
||||
class TestSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user