Allow different value dimensions in sdpa_vector (#1811)

This commit is contained in:
Angelos Katharopoulos
2025-01-31 20:58:59 -08:00
committed by GitHub
parent b7c9f1d38f
commit f5cc1eea72
6 changed files with 127 additions and 72 deletions

View File

@@ -124,6 +124,8 @@ void sdpa_vector(
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
@@ -185,6 +187,8 @@ void sdpa_vector_2pass(
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(v.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
@@ -256,7 +260,7 @@ void sdpa_vector_2pass(
kname += "sdpa_vector_2pass_2_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += std::to_string(v.shape(-1));
// Get the kernel
kernel = d.get_kernel(kname);
@@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
if (q.is_donatable() && q.size() == o.size()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));