SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa

* batch sdpa for query
Author: Awni Hannun
Date: 2025-03-04 10:59:04 -08:00
Committed by: GitHub
Parent: 6bcd6bcf70
Commit: e613d0eaf0
5 changed files with 159 additions and 45 deletions


@@ -134,14 +134,17 @@ void sdpa_vector(
   size_t k_stride = k.strides()[1];
   size_t v_stride = v.strides()[1];
   MTL::Size group_dims(1024, 1, 1);
-  MTL::Size grid_dims(1, B, 1);
+  MTL::Size grid_dims(B, q.shape(2), 1);
 
   bool has_mask = mask.has_value();
+  bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
+      {&query_transposed, MTL::DataType::DataTypeBool, 21},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += query_transposed ? "_qt" : "_qnt";
 
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
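The dispatch change above is the core of the PR: the grid used to be (1, B, 1), one threadgroup per flattened batch-times-heads row, with the query length pinned to 1. It is now (B, q.shape(2), 1), so each threadgroup handles one (row, query position) pair and the same vector kernel serves small batches of queries. The new query_transposed function constant tells the kernel which in-memory layout the query uses. A standalone sketch of the two offset computations the kernel has to choose between (illustrative names and values, not MLX source):

    // Illustrative only: offset of the query vector for (head h, position s)
    // under the two layouts the query_transposed constant distinguishes.
    // Row-contiguous: data ordered [heads][seq][dim];
    // head/seq transposed: data ordered [seq][heads][dim].
    #include <cstdio>

    int query_offset(bool transposed, int h, int s, int H, int S, int D) {
      return transposed ? (s * H + h) * D : (h * S + s) * D;
    }

    int main() {
      int H = 2, S = 3, D = 4; // heads, queries (<= 8 on this path), head_dim
      for (int h = 0; h < H; ++h)
        for (int s = 0; s < S; ++s)
          std::printf("head %d, query %d -> contiguous %2d, transposed %2d\n",
                      h, s, query_offset(false, h, s, H, S, D),
                      query_offset(true, h, s, H, S, D));
      return 0;
    }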
@@ -161,10 +164,14 @@ void sdpa_vector(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 9);
-    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
-    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
-    compute_encoder.set_bytes(seq_stride, 10);
-    compute_encoder.set_bytes(head_stride, 11);
+    auto nd = m.ndim();
+    int32_t kv_seq_stride =
+        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
+    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
+    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    compute_encoder.set_bytes(kv_seq_stride, 10);
+    compute_encoder.set_bytes(q_seq_stride, 11);
+    compute_encoder.set_bytes(head_stride, 12);
   }
 
   // Launch
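The mask handling now passes three strides instead of two, and zeroes a stride whenever the corresponding mask dimension has size 1. Zeroing is what implements broadcasting inside the kernel: a size-1 dimension must not advance the linear index, so every coordinate that broadcasts over it reads the same element. A minimal standalone illustration, with made-up values, for a [1, heads, 1, kv_seq] mask shared across query positions:

    #include <cstdio>
    #include <vector>

    int main() {
      // Mask of logical shape [1, heads=2, q_seq=1, kv_seq=4]: one row of 4
      // per head, shared across all query positions.
      int heads = 2, kv_seq = 4;
      std::vector<float> mask = {0, -1, 0, -1,   // head 0
                                 0, 0, -1, -1};  // head 1
      int kv_seq_stride = 1;    // kv_seq > 1: real stride
      int q_seq_stride = 0;     // q_seq == 1: broadcast, contributes nothing
      int head_stride = kv_seq; // heads > 1: real stride
      for (int h = 0; h < heads; ++h)
        for (int q = 0; q < 3; ++q)  // three query positions share one row
          for (int k = 0; k < kv_seq; ++k)
            std::printf(
                "m[h=%d,q=%d,k=%d] = %g\n", h, q, k,
                mask[h * head_stride + q * q_seq_stride + k * kv_seq_stride]);
      return 0;
    }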
@@ -198,7 +205,7 @@ void sdpa_vector_2pass(
   auto k_stride = k.strides()[1];
   auto v_stride = v.strides()[1];
   MTL::Size group_dims(8 * 32, 1, 1);
-  MTL::Size grid_dims(1, B, blocks);
+  MTL::Size grid_dims(B, q.shape(2), blocks);
 
   // Allocate the intermediates
   Shape intermediate_shape;
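The 2-pass variant gets the same new query dimension, with a third grid dimension, blocks, that partitions the KV sequence; each block writes partial results (including the per-block maxima and exp-sums behind the partials/sums/maxs temporaries below) that the second pass combines. The combination is the standard numerically stable merge of partial softmaxes; a scalar sketch of the idea, illustrative only and not the kernel's actual bookkeeping:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct Partial {
      float m; // block max of the logits
      float s; // sum of exp(logit - m) within the block
      float o; // unnormalized value accumulation for one output element
    };

    int main() {
      std::vector<Partial> partials = {{1.0f, 2.0f, 0.3f}, {2.5f, 1.5f, 0.8f}};
      float M = -INFINITY;
      for (auto& p : partials) M = std::fmax(M, p.m);
      float S = 0.0f, O = 0.0f;
      for (auto& p : partials) {
        float scale = std::exp(p.m - M); // rebase block onto the global max
        S += p.s * scale;
        O += p.o * scale;
      }
      std::printf("combined output: %f\n", O / S);
      return 0;
    }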
@@ -219,11 +226,14 @@ void sdpa_vector_2pass(
   d.add_temporary(maxs, s.index);
 
   bool has_mask = mask.has_value();
+  bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
+      {&query_transposed, MTL::DataType::DataTypeBool, 21},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += query_transposed ? "_qt" : "_qnt";
 
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
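As in the single-pass path, both function constants are folded into hash_name. That string is the cache key for the specialized pipeline, so every function constant that changes the compiled kernel must also change the key, or two different specializations would collide in the cache. A toy illustration of the invariant, using a hypothetical cache rather than the MLX device API:

    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, int> pipeline_cache; // key -> stand-in pipeline id
      auto get_pipeline = [&](std::string name, bool has_mask, bool qt) {
        name += has_mask ? "_mask" : "_nomask";
        name += qt ? "_qt" : "_qnt";
        // Build (here: just number) a pipeline only on a cache miss.
        auto [it, inserted] =
            pipeline_cache.try_emplace(name, (int)pipeline_cache.size());
        std::printf("%s -> pipeline %d%s\n", name.c_str(), it->second,
                    inserted ? " (built)" : " (cached)");
        return it->second;
      };
      get_pipeline("sdpa_vector", true, false);
      get_pipeline("sdpa_vector", true, true);  // distinct key, new pipeline
      get_pipeline("sdpa_vector", true, false); // cache hit
      return 0;
    }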
@@ -246,10 +256,14 @@ void sdpa_vector_2pass(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 11);
-    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
-    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
-    compute_encoder.set_bytes(seq_stride, 12);
-    compute_encoder.set_bytes(head_stride, 13);
+    auto nd = m.ndim();
+    int32_t kv_seq_stride =
+        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
+    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
+    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    compute_encoder.set_bytes(kv_seq_stride, 12);
+    compute_encoder.set_bytes(q_seq_stride, 13);
+    compute_encoder.set_bytes(head_stride, 14);
   }
 
   // Launch
@@ -274,7 +288,7 @@ void sdpa_vector_2pass(
 
   // Launch
   group_dims = MTL::Size(1024, 1, 1);
-  grid_dims = MTL::Size(1, B, 1);
+  grid_dims = MTL::Size(B, q.shape(2), 1);
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
@@ -301,16 +315,23 @@ void ScaledDotProductAttention::eval_gpu(
     if (!predicate(arr)) {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
       copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
+      copies.push_back(std::move(arr_copy));
       return copies.back();
     } else {
       return arr;
     }
   };
 
-  // Checks if arr is fully row contiguous
-  auto is_contiguous = [](const array& arr) {
-    return arr.flags().row_contiguous;
+  // Checks if arr is row contiguous or the sequence and head dimension are
+  // transposed
+  auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return true;
+    }
+    auto& strides = arr.strides();
+    auto& shape = arr.shape();
+    return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
+        (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
   };
 
   // Returns true if the array is row contiguous except the sequence length
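Besides plain row-contiguous arrays, the new predicate accepts the layout produced by transposing the head and sequence axes of a row-contiguous array. That is what lets a (1, heads, seq, dim) query whose data is ordered [seq][heads][dim] skip the copy. A worked example of the stride check, with illustrative values:

    #include <cassert>

    int main() {
      long shape[4] = {1, 8, 4, 64}; // (batch, heads, seq, head_dim)
      long H = shape[1], S = shape[2], D = shape[3];
      // An array stored in [seq][heads][dim] order (axes 1 and 2 of a
      // row-contiguous (1, seq, heads, dim) array swapped) has these strides:
      long strides[4] = {S * H * D, D, H * D, 1};
      // The predicate from the diff accepts exactly this pattern:
      bool ok = (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
          (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
      assert(ok);
      return 0;
    }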
@@ -328,18 +349,30 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   // We are in vector mode ie single query
-  if (q_pre.shape(2) == 1) {
-    const auto& q = copy_unless(is_contiguous, q_pre);
+  // 1, heads, seq_len, head_dim
+  // mask [1, query_heads, 1, seq_len]
+  if (q_pre.shape(2) <= 8) {
+    const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
     const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
     const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
 
     // Donate the query if possible
-    if (q.is_donatable() && q.size() == o.size()) {
+    if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
+        q.size() == o.size()) {
       o.move_shared_buffer(q);
     } else {
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      if (o.shape(2) == 1) {
+        o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      } else {
+        auto strides = o.strides();
+        strides[2] = o.shape(1) * o.shape(3);
+        strides[1] = o.shape(3);
+        auto flags = q.flags();
+        flags.row_contiguous = q.shape(1) == 1;
+        o.set_data(
+            allocator::malloc_or_wait(o.nbytes()),
+            o.size(),
+            std::move(strides),
+            flags);
+      }
     }
 
     auto mask =
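When more than one query is present, the output is allocated with [seq][heads][dim] strides, matching the layout the predicate above accepts, so the kernel can write o the same way it addresses q; row_contiguous is set only when there is a single head, in which case the two layouts coincide. A quick standalone check that the stride arithmetic in the else branch reproduces that layout (illustrative values, not MLX source):

    #include <cassert>

    int main() {
      long shape[4] = {1, 8, 4, 64}; // o: (batch, heads, seq, head_dim)
      long H = shape[1], S = shape[2], D = shape[3];
      long strides[4] = {H * S * D, S * D, D, 1}; // row-contiguous defaults
      strides[2] = shape[1] * shape[3];           // seq stride  -> H * D
      strides[1] = shape[3];                      // head stride -> D
      // Result is the head/seq-transposed layout checked earlier.
      assert(strides[2] == H * D && strides[1] == D &&
             strides[0] == strides[2] * S);
      return 0;
    }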