mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix batched vector sdpa (#2152)
This commit is contained in:
parent
825124af8f
commit
af705590ac
@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
|
|||||||
const int head_idx = tid.x;
|
const int head_idx = tid.x;
|
||||||
const int q_seq_idx = tid.y;
|
const int q_seq_idx = tid.y;
|
||||||
const int kv_head_idx = head_idx / gqa_factor;
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||||
const int q_offset =
|
const int q_offset =
|
||||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||||
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||||
simd_lid * qk_per_thread;
|
simd_lid * qk_per_thread;
|
||||||
@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
|
|||||||
const int block_idx = tid.z;
|
const int block_idx = tid.z;
|
||||||
const int head_idx = tid.x;
|
const int head_idx = tid.x;
|
||||||
const int q_seq_idx = tid.y;
|
const int q_seq_idx = tid.y;
|
||||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||||
const int q_offset =
|
const int q_offset =
|
||||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||||
const int kv_head_idx = head_idx / gqa_factor;
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
|
|
||||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||||
@ -358,8 +358,8 @@ template <typename T, int D>
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int head_idx = tid.x;
|
const int head_idx = tid.x;
|
||||||
const int q_seq_idx = tid.y;
|
const int q_seq_idx = tid.y;
|
||||||
const int n_heads = tpg.x;
|
const int q_offset = head_idx * tpg.y + q_seq_idx;
|
||||||
const int q_offset = n_heads * q_seq_idx + head_idx;
|
;
|
||||||
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||||
sums += q_offset * blocks;
|
sums += q_offset * blocks;
|
||||||
maxs += q_offset * blocks;
|
maxs += q_offset * blocks;
|
||||||
|
@ -154,9 +154,9 @@ void sdpa_vector(
|
|||||||
int gqa_factor = q.shape(1) / k.shape(1);
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
size_t k_head_stride = k.strides()[1];
|
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||||
size_t k_seq_stride = k.strides()[2];
|
size_t k_seq_stride = k.strides()[2];
|
||||||
size_t v_head_stride = v.strides()[1];
|
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||||
size_t v_seq_stride = v.strides()[2];
|
size_t v_seq_stride = v.strides()[2];
|
||||||
|
|
||||||
MTL::Size group_dims(1024, 1, 1);
|
MTL::Size group_dims(1024, 1, 1);
|
||||||
@ -199,11 +199,10 @@ void sdpa_vector(
|
|||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 11 + float_mask);
|
compute_encoder.set_input_array(m, 11 + float_mask);
|
||||||
auto nd = m.ndim();
|
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||||
int32_t kv_seq_stride =
|
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
int32_t head_stride =
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 13);
|
compute_encoder.set_bytes(kv_seq_stride, 13);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 14);
|
compute_encoder.set_bytes(q_seq_stride, 14);
|
||||||
compute_encoder.set_bytes(head_stride, 15);
|
compute_encoder.set_bytes(head_stride, 15);
|
||||||
@ -238,9 +237,10 @@ void sdpa_vector_2pass(
|
|||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int blocks = 32;
|
int blocks = 32;
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
size_t k_head_stride = k.strides()[1];
|
|
||||||
|
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||||
size_t k_seq_stride = k.strides()[2];
|
size_t k_seq_stride = k.strides()[2];
|
||||||
size_t v_head_stride = v.strides()[1];
|
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||||
size_t v_seq_stride = v.strides()[2];
|
size_t v_seq_stride = v.strides()[2];
|
||||||
MTL::Size group_dims(8 * 32, 1, 1);
|
MTL::Size group_dims(8 * 32, 1, 1);
|
||||||
MTL::Size grid_dims(B, q.shape(2), blocks);
|
MTL::Size grid_dims(B, q.shape(2), blocks);
|
||||||
@ -302,11 +302,10 @@ void sdpa_vector_2pass(
|
|||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 13 + float_mask);
|
compute_encoder.set_input_array(m, 13 + float_mask);
|
||||||
auto nd = m.ndim();
|
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||||
int32_t kv_seq_stride =
|
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
int32_t head_stride =
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 15);
|
compute_encoder.set_bytes(kv_seq_stride, 15);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 16);
|
compute_encoder.set_bytes(q_seq_stride, 16);
|
||||||
compute_encoder.set_bytes(head_stride, 17);
|
compute_encoder.set_bytes(head_stride, 17);
|
||||||
@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Checks if arr is row contiguous or the sequence and head dimension are
|
|
||||||
// transposed
|
|
||||||
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
|
|
||||||
if (arr.flags().row_contiguous) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
auto& strides = arr.strides();
|
|
||||||
auto& shape = arr.shape();
|
|
||||||
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
|
|
||||||
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Checks that the headdim dimension has stride 1.
|
// Checks that the headdim dimension has stride 1.
|
||||||
auto is_matrix_contiguous = [](const array& arr) {
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
return arr.strides(-1) == 1;
|
return arr.strides(-1) == 1;
|
||||||
@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
if (q_pre.shape(2) <= 8) {
|
if (q_pre.shape(2) <= 8) {
|
||||||
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
auto q_copy_unless = [](const array& arr) {
|
||||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
if (arr.flags().row_contiguous) {
|
||||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
return true;
|
||||||
|
}
|
||||||
|
auto& strides = arr.strides();
|
||||||
|
auto& shape = arr.shape();
|
||||||
|
if (shape[0] == 1 || shape[1] == 1) {
|
||||||
|
// If either the batch or head dimension is a singleton, the other can
|
||||||
|
// be transposed with the sequence dimension
|
||||||
|
auto bidx = shape[0] == 1 ? 1 : 0;
|
||||||
|
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
|
||||||
|
(strides[bidx] == shape[3]);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto kv_copy_unless = [](const array& arr) {
|
||||||
|
// keys and values should be copied if:
|
||||||
|
// - the last dimension is not contiguous
|
||||||
|
// - the batch and head dim are not contiguous
|
||||||
|
auto& strides = arr.strides();
|
||||||
|
auto& shape = arr.shape();
|
||||||
|
if (strides.back() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (shape[0] == 1 || shape[1] == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return (strides[0] == strides[1] * shape[1]);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto& q = copy_unless(q_copy_unless, q_pre);
|
||||||
|
const auto& k = copy_unless(kv_copy_unless, k_pre);
|
||||||
|
const auto& v = copy_unless(kv_copy_unless, v_pre);
|
||||||
|
|
||||||
// Donate the query if possible
|
// Donate the query if possible
|
||||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||||
q.size() == o.size()) {
|
|
||||||
o.copy_shared_buffer(q);
|
o.copy_shared_buffer(q);
|
||||||
} else {
|
} else {
|
||||||
if (o.shape(2) == 1) {
|
o.set_data(allocator::malloc(o.nbytes()));
|
||||||
o.set_data(allocator::malloc(o.nbytes()));
|
|
||||||
} else {
|
|
||||||
auto strides = o.strides();
|
|
||||||
strides[2] = o.shape(1) * o.shape(3);
|
|
||||||
strides[1] = o.shape(3);
|
|
||||||
auto flags = q.flags();
|
|
||||||
flags.row_contiguous = q.shape(1) == 1;
|
|
||||||
o.set_data(
|
|
||||||
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mask =
|
auto mask_copy_unless = [&q](const array& arr) {
|
||||||
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
|
auto& strides = arr.strides();
|
||||||
|
auto& shape = arr.shape();
|
||||||
|
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
|
||||||
|
(strides[0] == strides[1] * shape[1]);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto mask = inputs.size() > 3
|
||||||
|
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
|
||||||
|
: std::nullopt;
|
||||||
|
|
||||||
// We route to the 2 pass fused attention if
|
// We route to the 2 pass fused attention if
|
||||||
// - The device is large and the sequence length long
|
// - The device is large and the sequence length long
|
||||||
|
@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
def test_sdpa_vector_batched(self):
|
||||||
|
D = 64
|
||||||
|
q = mx.random.normal(shape=(2, 1, 3, D))
|
||||||
|
k = mx.random.normal(shape=(2, 1, 3, D))
|
||||||
|
v = mx.random.normal(shape=(2, 1, 3, D))
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||||
|
k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
|
||||||
|
v = mx.random.normal(shape=(2, 2, 3, D))
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(2, 4, 3, D))
|
||||||
|
k = mx.random.normal(shape=(2, 1, 3, D))
|
||||||
|
v = mx.random.normal(shape=(2, 1, 3, D))
|
||||||
|
mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||||
|
ref = mlx_ref_attn(q, k, v, mask=mask)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
|
||||||
class TestSDPA(mlx_tests.MLXTestCase):
|
class TestSDPA(mlx_tests.MLXTestCase):
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user