mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv
* fix flaky test
* nit
This commit is contained in:
parent
cffceda6ee
commit
3c3e558c60
@ -15,8 +15,10 @@ template <typename T, int D, int V = D>
|
|||||||
device T* out [[buffer(3)]],
|
device T* out [[buffer(3)]],
|
||||||
const constant int& gqa_factor,
|
const constant int& gqa_factor,
|
||||||
const constant int& N,
|
const constant int& N,
|
||||||
const constant size_t& k_stride,
|
const constant size_t& k_head_stride,
|
||||||
const constant size_t& v_stride,
|
const constant size_t& k_seq_stride,
|
||||||
|
const constant size_t& v_head_stride,
|
||||||
|
const constant size_t& v_seq_stride,
|
||||||
const constant float& scale,
|
const constant float& scale,
|
||||||
const device bool* mask [[function_constant(has_mask)]],
|
const device bool* mask [[function_constant(has_mask)]],
|
||||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||||
@ -30,8 +32,8 @@ template <typename T, int D, int V = D>
|
|||||||
constexpr int BD = 32;
|
constexpr int BD = 32;
|
||||||
constexpr int qk_per_thread = D / BD;
|
constexpr int qk_per_thread = D / BD;
|
||||||
constexpr int v_per_thread = V / BD;
|
constexpr int v_per_thread = V / BD;
|
||||||
constexpr int inner_k_stride = BN * D;
|
int inner_k_stride = BN * int(k_seq_stride);
|
||||||
constexpr int inner_v_stride = BN * V;
|
int inner_v_stride = BN * int(v_seq_stride);
|
||||||
|
|
||||||
typedef float U;
|
typedef float U;
|
||||||
|
|
||||||
@ -51,8 +53,10 @@ template <typename T, int D, int V = D>
|
|||||||
const int q_offset =
|
const int q_offset =
|
||||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
|
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||||
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
|
simd_lid * qk_per_thread;
|
||||||
|
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
|
||||||
|
simd_lid * v_per_thread;
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||||
q_seq_idx * mask_q_seq_stride;
|
q_seq_idx * mask_q_seq_stride;
|
||||||
@ -147,8 +151,10 @@ template <typename T, int D, int V = D>
|
|||||||
device float* maxs [[buffer(5)]],
|
device float* maxs [[buffer(5)]],
|
||||||
const constant int& gqa_factor,
|
const constant int& gqa_factor,
|
||||||
const constant int& N,
|
const constant int& N,
|
||||||
const constant size_t& k_stride,
|
const constant size_t& k_head_stride,
|
||||||
const constant size_t& v_stride,
|
const constant size_t& k_seq_stride,
|
||||||
|
const constant size_t& v_head_stride,
|
||||||
|
const constant size_t& v_seq_stride,
|
||||||
const constant float& scale,
|
const constant float& scale,
|
||||||
const device bool* mask [[function_constant(has_mask)]],
|
const device bool* mask [[function_constant(has_mask)]],
|
||||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||||
@ -162,8 +168,8 @@ template <typename T, int D, int V = D>
|
|||||||
constexpr int BD = 32;
|
constexpr int BD = 32;
|
||||||
constexpr int qk_per_thread = D / BD;
|
constexpr int qk_per_thread = D / BD;
|
||||||
constexpr int v_per_thread = V / BD;
|
constexpr int v_per_thread = V / BD;
|
||||||
constexpr int inner_k_stride = BN * D;
|
int inner_k_stride = BN * int(k_seq_stride);
|
||||||
constexpr int inner_v_stride = BN * V;
|
int inner_v_stride = BN * int(v_seq_stride);
|
||||||
constexpr int blocks = 32;
|
constexpr int blocks = 32;
|
||||||
|
|
||||||
typedef float U;
|
typedef float U;
|
||||||
@ -186,10 +192,10 @@ template <typename T, int D, int V = D>
|
|||||||
const int kv_head_idx = head_idx / gqa_factor;
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
|
|
||||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||||
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
|
keys += kv_head_idx * k_head_stride +
|
||||||
simd_lid * qk_per_thread;
|
(block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
|
||||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
|
values += kv_head_idx * v_head_stride +
|
||||||
simd_lid * v_per_thread;
|
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
||||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
mask += head_idx * mask_head_stride +
|
mask += head_idx * mask_head_stride +
|
||||||
|
@ -131,8 +131,11 @@ void sdpa_vector(
|
|||||||
int gqa_factor = q.shape(1) / k.shape(1);
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
size_t k_stride = k.strides()[1];
|
size_t k_head_stride = k.strides()[1];
|
||||||
size_t v_stride = v.strides()[1];
|
size_t k_seq_stride = k.strides()[2];
|
||||||
|
size_t v_head_stride = v.strides()[1];
|
||||||
|
size_t v_seq_stride = v.strides()[2];
|
||||||
|
|
||||||
MTL::Size group_dims(1024, 1, 1);
|
MTL::Size group_dims(1024, 1, 1);
|
||||||
MTL::Size grid_dims(B, q.shape(2), 1);
|
MTL::Size grid_dims(B, q.shape(2), 1);
|
||||||
|
|
||||||
@ -158,20 +161,23 @@ void sdpa_vector(
|
|||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
compute_encoder.set_bytes(gqa_factor, 4);
|
compute_encoder.set_bytes(gqa_factor, 4);
|
||||||
compute_encoder.set_bytes(N, 5);
|
compute_encoder.set_bytes(N, 5);
|
||||||
compute_encoder.set_bytes(k_stride, 6);
|
compute_encoder.set_bytes(k_head_stride, 6);
|
||||||
compute_encoder.set_bytes(v_stride, 7);
|
compute_encoder.set_bytes(k_seq_stride, 7);
|
||||||
compute_encoder.set_bytes(scale, 8);
|
compute_encoder.set_bytes(v_head_stride, 8);
|
||||||
|
compute_encoder.set_bytes(v_seq_stride, 9);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(scale, 10);
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 9);
|
compute_encoder.set_input_array(m, 11);
|
||||||
auto nd = m.ndim();
|
auto nd = m.ndim();
|
||||||
int32_t kv_seq_stride =
|
int32_t kv_seq_stride =
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 10);
|
compute_encoder.set_bytes(kv_seq_stride, 12);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 11);
|
compute_encoder.set_bytes(q_seq_stride, 13);
|
||||||
compute_encoder.set_bytes(head_stride, 12);
|
compute_encoder.set_bytes(head_stride, 14);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
@ -202,8 +208,10 @@ void sdpa_vector_2pass(
|
|||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int blocks = 32;
|
int blocks = 32;
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
auto k_stride = k.strides()[1];
|
size_t k_head_stride = k.strides()[1];
|
||||||
auto v_stride = v.strides()[1];
|
size_t k_seq_stride = k.strides()[2];
|
||||||
|
size_t v_head_stride = v.strides()[1];
|
||||||
|
size_t v_seq_stride = v.strides()[2];
|
||||||
MTL::Size group_dims(8 * 32, 1, 1);
|
MTL::Size group_dims(8 * 32, 1, 1);
|
||||||
MTL::Size grid_dims(B, q.shape(2), blocks);
|
MTL::Size grid_dims(B, q.shape(2), blocks);
|
||||||
|
|
||||||
@ -250,20 +258,22 @@ void sdpa_vector_2pass(
|
|||||||
compute_encoder.set_output_array(maxs, 5);
|
compute_encoder.set_output_array(maxs, 5);
|
||||||
compute_encoder.set_bytes(gqa_factor, 6);
|
compute_encoder.set_bytes(gqa_factor, 6);
|
||||||
compute_encoder.set_bytes(N, 7);
|
compute_encoder.set_bytes(N, 7);
|
||||||
compute_encoder.set_bytes(k_stride, 8);
|
compute_encoder.set_bytes(k_head_stride, 8);
|
||||||
compute_encoder.set_bytes(v_stride, 9);
|
compute_encoder.set_bytes(k_seq_stride, 9);
|
||||||
compute_encoder.set_bytes(scale, 10);
|
compute_encoder.set_bytes(v_head_stride, 10);
|
||||||
|
compute_encoder.set_bytes(v_seq_stride, 11);
|
||||||
|
compute_encoder.set_bytes(scale, 12);
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 11);
|
compute_encoder.set_input_array(m, 13);
|
||||||
auto nd = m.ndim();
|
auto nd = m.ndim();
|
||||||
int32_t kv_seq_stride =
|
int32_t kv_seq_stride =
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 12);
|
compute_encoder.set_bytes(kv_seq_stride, 14);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 13);
|
compute_encoder.set_bytes(q_seq_stride, 15);
|
||||||
compute_encoder.set_bytes(head_stride, 14);
|
compute_encoder.set_bytes(head_stride, 16);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
@ -334,15 +344,6 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns true if the array is row contiguous except the sequence length
|
|
||||||
// dimension that can be sliced but with step=1.
|
|
||||||
auto is_contiguous_except_seq_len = [](const array& arr) {
|
|
||||||
auto& strides = arr.strides();
|
|
||||||
auto& shape = arr.shape();
|
|
||||||
return strides[3] == 1 && strides[2] == shape[3] &&
|
|
||||||
strides[0] == strides[1] * shape[1];
|
|
||||||
};
|
|
||||||
|
|
||||||
// Checks that the headdim dimension has stride 1.
|
// Checks that the headdim dimension has stride 1.
|
||||||
auto is_matrix_contiguous = [](const array& arr) {
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
return arr.strides(3) == 1;
|
return arr.strides(3) == 1;
|
||||||
@ -351,8 +352,8 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
if (q_pre.shape(2) <= 8) {
|
if (q_pre.shape(2) <= 8) {
|
||||||
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
||||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||||
|
|
||||||
// Donate the query if possible
|
// Donate the query if possible
|
||||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
||||||
|
@ -183,9 +183,11 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
|||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
y = mx.distributed.all_sum(x)
|
y = mx.distributed.all_sum(x)
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
all_sum_only = mx.metal.get_peak_memory()
|
all_sum_only = mx.metal.get_peak_memory()
|
||||||
y = mx.distributed.all_sum(x) * scale
|
y = mx.distributed.all_sum(x) * scale
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
all_sum_with_binary = mx.metal.get_peak_memory()
|
all_sum_with_binary = mx.metal.get_peak_memory()
|
||||||
|
|
||||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||||
|
@ -171,7 +171,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
rtol = 1e-2
|
rtol = 1e-2
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
||||||
|
|
||||||
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
||||||
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
||||||
v = mx.random.normal(shape=(1, 32, 128, Dk))
|
v = mx.random.normal(shape=(1, 32, 128, Dk))
|
||||||
@ -201,6 +200,38 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||||
|
|
||||||
|
def test_fast_sdpa_vector_kv_transposed_head_seq(self):
|
||||||
|
D = 64
|
||||||
|
Nq = 4
|
||||||
|
Nkv = 1
|
||||||
|
scale = 1.0
|
||||||
|
mx.random.seed(0)
|
||||||
|
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
|
||||||
|
|
||||||
|
lengths = [43, 4096]
|
||||||
|
for L in lengths:
|
||||||
|
k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
|
||||||
|
v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
|
||||||
|
k = k.swapaxes(1, 2)
|
||||||
|
v = v.swapaxes(1, 2)
|
||||||
|
masks = [
|
||||||
|
mx.array(True),
|
||||||
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
|
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for m in masks:
|
||||||
|
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||||
|
out = mx.fast.scaled_dot_product_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
scale=scale,
|
||||||
|
mask=m,
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
def test_fast_sdpa_vector(self):
|
def test_fast_sdpa_vector(self):
|
||||||
D = 64
|
D = 64
|
||||||
L = 43
|
L = 43
|
||||||
@ -292,7 +323,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
return
|
|
||||||
L = 4096
|
L = 4096
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
mx.random.seed(0)
|
mx.random.seed(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user