mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv
* fix flaky test
* nit
This commit is contained in:
@@ -131,8 +131,11 @@ void sdpa_vector(
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_stride = k.strides()[1];
|
||||
size_t v_stride = v.strides()[1];
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
MTL::Size grid_dims(B, q.shape(2), 1);
|
||||
|
||||
@@ -158,20 +161,23 @@ void sdpa_vector(
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(gqa_factor, 4);
|
||||
compute_encoder.set_bytes(N, 5);
|
||||
compute_encoder.set_bytes(k_stride, 6);
|
||||
compute_encoder.set_bytes(v_stride, 7);
|
||||
compute_encoder.set_bytes(scale, 8);
|
||||
compute_encoder.set_bytes(k_head_stride, 6);
|
||||
compute_encoder.set_bytes(k_seq_stride, 7);
|
||||
compute_encoder.set_bytes(v_head_stride, 8);
|
||||
compute_encoder.set_bytes(v_seq_stride, 9);
|
||||
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 9);
|
||||
compute_encoder.set_input_array(m, 11);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
compute_encoder.set_bytes(kv_seq_stride, 10);
|
||||
compute_encoder.set_bytes(q_seq_stride, 11);
|
||||
compute_encoder.set_bytes(head_stride, 12);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 12);
|
||||
compute_encoder.set_bytes(q_seq_stride, 13);
|
||||
compute_encoder.set_bytes(head_stride, 14);
|
||||
}
|
||||
|
||||
// Launch
|
||||
@@ -202,8 +208,10 @@ void sdpa_vector_2pass(
|
||||
int N = k.shape(2);
|
||||
int blocks = 32;
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
auto k_stride = k.strides()[1];
|
||||
auto v_stride = v.strides()[1];
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
MTL::Size group_dims(8 * 32, 1, 1);
|
||||
MTL::Size grid_dims(B, q.shape(2), blocks);
|
||||
|
||||
@@ -250,20 +258,22 @@ void sdpa_vector_2pass(
|
||||
compute_encoder.set_output_array(maxs, 5);
|
||||
compute_encoder.set_bytes(gqa_factor, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
compute_encoder.set_bytes(k_stride, 8);
|
||||
compute_encoder.set_bytes(v_stride, 9);
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
compute_encoder.set_bytes(k_head_stride, 8);
|
||||
compute_encoder.set_bytes(k_seq_stride, 9);
|
||||
compute_encoder.set_bytes(v_head_stride, 10);
|
||||
compute_encoder.set_bytes(v_seq_stride, 11);
|
||||
compute_encoder.set_bytes(scale, 12);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11);
|
||||
compute_encoder.set_input_array(m, 13);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
compute_encoder.set_bytes(kv_seq_stride, 12);
|
||||
compute_encoder.set_bytes(q_seq_stride, 13);
|
||||
compute_encoder.set_bytes(head_stride, 14);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 14);
|
||||
compute_encoder.set_bytes(q_seq_stride, 15);
|
||||
compute_encoder.set_bytes(head_stride, 16);
|
||||
}
|
||||
|
||||
// Launch
|
||||
@@ -334,15 +344,6 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
||||
};
|
||||
|
||||
// Returns true if the array is row contiguous except the sequence length
|
||||
// dimension that can be sliced but with step=1.
|
||||
auto is_contiguous_except_seq_len = [](const array& arr) {
  // Layout predicate over a 4-axis array. It accepts only when:
  //   * the last axis (3) has stride 1,
  //   * the stride of axis 2 equals the extent of axis 3, and
  //   * the stride of axis 0 equals stride(1) * extent(1).
  // The stride of axis 1 itself is not compared against anything.
  const auto& dims = arr.shape();
  const auto& steps = arr.strides();

  // Innermost axis must be densely packed.
  if (steps[3] != 1) {
    return false;
  }
  // Consecutive entries along axis 2 must be exactly one row of axis 3 apart.
  if (steps[2] != dims[3]) {
    return false;
  }
  // Axis 0 must step over all of axis 1 with no extra gap.
  return steps[0] == steps[1] * dims[1];
};
|
||||
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(3) == 1;
|
||||
@@ -351,8 +352,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) <= 8) {
|
||||
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
||||
|
||||
Reference in New Issue
Block a user