mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow different value dimensions in sdpa_vector (#1811)
This commit is contained in:
committed by
GitHub
parent
b7c9f1d38f
commit
f5cc1eea72
@@ -124,6 +124,8 @@ void sdpa_vector(
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
kname += std::to_string(q.shape(-1));
|
||||
+ kname += "_";
|
||||
+ kname += std::to_string(v.shape(-1));
|
||||
|
||||
// Compute the necessary sizes
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
@@ -185,6 +187,8 @@ void sdpa_vector_2pass(
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
kname += std::to_string(q.shape(-1));
|
||||
+ kname += "_";
|
||||
+ kname += std::to_string(v.shape(-1));
|
||||
|
||||
// Compute the necessary sizes
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
@@ -256,7 +260,7 @@ void sdpa_vector_2pass(
|
||||
kname += "sdpa_vector_2pass_2_";
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
- kname += std::to_string(q.shape(-1));
|
||||
+ kname += std::to_string(v.shape(-1));
|
||||
|
||||
// Get the kernel
|
||||
kernel = d.get_kernel(kname);
|
||||
@@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
- if (q.is_donatable()) {
|
||||
+ if (q.is_donatable() && q.size() == o.size()) {
|
||||
o.move_shared_buffer(q);
|
||||
} else {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
|
||||
Reference in New Issue
Block a user