Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)

SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa
* batch sdpa for query

This commit is contained in: parent 6bcd6bcf70, commit e613d0eaf0
@@ -5,6 +5,7 @@
 using namespace metal;
 
 constant bool has_mask [[function_constant(20)]];
+constant bool query_transposed [[function_constant(21)]];
 
 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -18,9 +19,11 @@ template <typename T, int D, int V = D>
     const constant size_t& v_stride,
     const constant float& scale,
     const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
     const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
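Note: the single mask_seq_stride parameter is split into mask_kv_seq_stride and mask_q_seq_stride, so a boolean mask can vary, or broadcast, independently along the query and KV sequence axes. A minimal Python mirror of the addressing the kernel performs with these strides (the helper name is illustrative, not from the diff):

    def mask_element(h, i, j, head_stride, q_seq_stride, kv_seq_stride):
        # Flat index for (head h, query position i, key position j);
        # a stride of 0 broadcasts the mask along that axis.
        return h * head_stride + i * q_seq_stride + j * kv_seq_stride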
@@ -41,15 +44,21 @@ template <typename T, int D, int V = D>
   threadgroup U sum_exp_scores[BN];
 
   // Adjust positions
-  const int head_idx = tid.y;
+  const int head_idx = tid.x;
+  const int q_seq_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * qk_per_thread;
+  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int q_offset =
+      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
+  queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
   values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
   if (has_mask) {
-    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+    mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
   }
-  out += head_idx * V + simd_gid * v_per_thread;
+  out += o_offset * V + simd_gid * v_per_thread;
 
   // Read the query and 0 the output accumulator
   for (int i = 0; i < qk_per_thread; i++) {
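Note: the launch grid is now two dimensional, with tid.x enumerating heads (the host launches B threadgroups along x) and tid.y enumerating query positions. Outputs are always written in (query position, head) order through o_offset, while q_offset follows whichever of the two supported query layouts is in memory. A small Python sketch of that index arithmetic, where n_heads and q_seq_len stand in for tpg.x and tpg.y:

    def query_and_output_offsets(head_idx, q_seq_idx, n_heads, q_seq_len, query_transposed):
        o_offset = n_heads * q_seq_idx + head_idx
        # Transposed queries already match the output ordering; row-contiguous
        # queries are ordered (head, query position) instead.
        q_offset = o_offset if query_transposed else head_idx * q_seq_len + q_seq_idx
        return q_offset, o_offset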
@@ -95,7 +104,7 @@ template <typename T, int D, int V = D>
     keys += inner_k_stride;
     values += inner_v_stride;
     if (has_mask) {
-      mask += BN * mask_seq_stride;
+      mask += BN * mask_kv_seq_stride;
     }
   }
 
@@ -142,9 +151,11 @@ template <typename T, int D, int V = D>
     const constant size_t& v_stride,
     const constant float& scale,
     const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
     const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 8;
@@ -167,20 +178,26 @@ template <typename T, int D, int V = D>
 
   // Adjust positions
   const int block_idx = tid.z;
-  const int head_idx = tid.y;
+  const int head_idx = tid.x;
+  const int q_seq_idx = tid.y;
+  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int q_offset =
+      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * qk_per_thread;
+  queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
       simd_lid * v_per_thread;
-  out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
+  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride +
-        (block_idx * BN + simd_gid) * mask_seq_stride;
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
   }
-  sums += head_idx * blocks + block_idx;
-  maxs += head_idx * blocks + block_idx;
+  sums += o_offset * blocks + block_idx;
+  maxs += o_offset * blocks + block_idx;
 
   // Read the query and 0 the output accumulator
   for (int i = 0; i < qk_per_thread; i++) {
@@ -226,7 +243,7 @@ template <typename T, int D, int V = D>
     keys += blocks * inner_k_stride;
     values += blocks * inner_v_stride;
     if (has_mask) {
-      mask += BN * blocks * mask_seq_stride;
+      mask += BN * blocks * mask_kv_seq_stride;
     }
   }
 
@@ -275,6 +292,7 @@ template <typename T, int D>
     const device float* maxs [[buffer(2)]],
     device T* out [[buffer(3)]],
     uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
@@ -288,11 +306,14 @@ template <typename T, int D>
   threadgroup U outputs[BN * BD];
 
   // Adjust positions
-  const int head_idx = tid.y;
-  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
-  sums += head_idx * blocks;
-  maxs += head_idx * blocks;
-  out += head_idx * D + simd_gid * elem_per_thread;
+  const int head_idx = tid.x;
+  const int q_seq_idx = tid.y;
+  const int n_heads = tpg.x;
+  const int q_offset = n_heads * q_seq_idx + head_idx;
+  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+  sums += q_offset * blocks;
+  maxs += q_offset * blocks;
+  out += q_offset * D + simd_gid * elem_per_thread;
 
   // First everybody reads the max and sum_exp
   U max_score = maxs[simd_lid];
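Note: in the two-pass pipeline each (head, query position) pair now owns its own blocks of partials, sums, and maxs, addressed by the same flattened offset; the reduce kernel recomputes it as n_heads * q_seq_idx + head_idx. A sketch of how one row of the intermediates is located, assuming the row-major layout the pointer arithmetic implies:

    def partials_base(head_idx, q_seq_idx, n_heads, blocks, D):
        # partials is conceptually [n_heads * q_seq_len, blocks, D]
        q_offset = n_heads * q_seq_idx + head_idx
        return q_offset * blocks * D  # simd group and lane offsets are added in-kernel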
@@ -25,6 +25,10 @@ void RoPE::eval_gpu(
   size_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
+  int dispatch_ndim = in.ndim();
+  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
+    dispatch_ndim--;
+  }
   size_t mat_size = in.shape(-2) * in.shape(-1);
   if (dims_ < in.shape(-1)) {
     donated = true;
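Note: dispatch_ndim strips leading singleton dimensions, down to a floor of 3, so an input such as (1, H, S, D) can take the cheaper 3-D non-contiguous dispatch branch below instead of falling through to the copy path. The loop's effect, sketched in Python:

    def dispatch_ndim(shape):
        nd = len(shape)
        # Drop leading size-1 dims while more than 3 dims remain.
        while nd > 3 and shape[-nd] == 1:
            nd -= 1
        return nd

    # dispatch_ndim((1, 8, 43, 64)) == 3, so the 3-D branch applies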
@@ -44,12 +48,12 @@ void RoPE::eval_gpu(
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
-  } else if (ndim == 3) {
+  } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    strides[0] = in.strides()[0];
-    strides[1] = in.strides()[1];
-    strides[2] = in.strides()[2];
+    strides[0] = in.strides()[ndim - 3];
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
   } else {
     // Copy non-contiguous > 3D inputs into the output and treat
     // input as donated
@@ -134,14 +134,17 @@ void sdpa_vector(
   size_t k_stride = k.strides()[1];
   size_t v_stride = v.strides()[1];
   MTL::Size group_dims(1024, 1, 1);
-  MTL::Size grid_dims(1, B, 1);
+  MTL::Size grid_dims(B, q.shape(2), 1);
 
   bool has_mask = mask.has_value();
+  bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
+      {&query_transposed, MTL::DataType::DataTypeBool, 21},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += query_transposed ? "_qt" : "_qnt";
 
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
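Note: query_transposed becomes a second Metal function constant (index 21), so a specialized pipeline variant is compiled per (has_mask, query_transposed) combination, and the hash-name suffixes keep those variants distinct in the kernel cache. A sketch of the resulting cache key (hypothetical helper name):

    def kernel_variant_name(kname, has_mask, query_transposed):
        # Mirrors the hash_name construction above: one compiled variant
        # per function-constant combination.
        return kname + ("_mask" if has_mask else "_nomask") + ("_qt" if query_transposed else "_qnt")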
@@ -161,10 +164,14 @@ void sdpa_vector(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 9);
-    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
-    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
-    compute_encoder.set_bytes(seq_stride, 10);
-    compute_encoder.set_bytes(head_stride, 11);
+    auto nd = m.ndim();
+    int32_t kv_seq_stride =
+        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
+    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
+    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    compute_encoder.set_bytes(kv_seq_stride, 10);
+    compute_encoder.set_bytes(q_seq_stride, 11);
+    compute_encoder.set_bytes(head_stride, 12);
   }
 
   // Launch
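Note: each mask stride is zeroed when the corresponding dimension has size 1, which is exactly how broadcasting is realized inside the kernel: a zero stride makes every index along that axis read the same element. The same rule in Python:

    def broadcast_strides(shape, strides):
        # Zero the stride of every size-1 (broadcast) dimension, as done
        # above for the last three mask dimensions.
        return [0 if dim == 1 else st for dim, st in zip(shape, strides)]

    # A (Nq, 1, L) mask keeps its head and KV strides but broadcasts over
    # query positions: broadcast_strides((8, 1, 43), (43, 43, 1)) -> [43, 0, 1]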
@@ -198,7 +205,7 @@ void sdpa_vector_2pass(
   auto k_stride = k.strides()[1];
   auto v_stride = v.strides()[1];
   MTL::Size group_dims(8 * 32, 1, 1);
-  MTL::Size grid_dims(1, B, blocks);
+  MTL::Size grid_dims(B, q.shape(2), blocks);
 
   // Allocate the intermediates
   Shape intermediate_shape;
@@ -219,11 +226,14 @@ void sdpa_vector_2pass(
   d.add_temporary(maxs, s.index);
 
   bool has_mask = mask.has_value();
+  bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
+      {&query_transposed, MTL::DataType::DataTypeBool, 21},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += query_transposed ? "_qt" : "_qnt";
 
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -246,10 +256,14 @@ void sdpa_vector_2pass(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 11);
-    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
-    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
-    compute_encoder.set_bytes(seq_stride, 12);
-    compute_encoder.set_bytes(head_stride, 13);
+    auto nd = m.ndim();
+    int32_t kv_seq_stride =
+        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
+    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
+    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    compute_encoder.set_bytes(kv_seq_stride, 12);
+    compute_encoder.set_bytes(q_seq_stride, 13);
+    compute_encoder.set_bytes(head_stride, 14);
   }
 
   // Launch
@@ -274,7 +288,7 @@ void sdpa_vector_2pass(
 
   // Launch
   group_dims = MTL::Size(1024, 1, 1);
-  grid_dims = MTL::Size(1, B, 1);
+  grid_dims = MTL::Size(B, q.shape(2), 1);
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
@@ -301,16 +315,23 @@ void ScaledDotProductAttention::eval_gpu(
     if (!predicate(arr)) {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
       copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
+      copies.push_back(std::move(arr_copy));
       return copies.back();
     } else {
      return arr;
    }
  };
 
-  // Checks if arr is fully row contiguous
-  auto is_contiguous = [](const array& arr) {
-    return arr.flags().row_contiguous;
+  // Checks if arr is row contiguous or the sequence and head dimension are
+  // transposed
+  auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return true;
+    }
+    auto& strides = arr.strides();
+    auto& shape = arr.shape();
+    return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
+        (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
   };
 
   // Returns true if the array is row contiguous except the sequence length
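Note: for a logical shape (B, H, S, D), the new predicate also accepts the layout of a row-contiguous (B, S, H, D) array whose sequence and head axes were swapped, i.e. strides (S*H*D, D, H*D, 1). An equivalent Python check:

    def is_contiguous_or_head_seq_transposed(shape, strides):
        B, H, S, D = shape
        row_contiguous = list(strides) == [H * S * D, S * D, D, 1]
        seq_head_transposed = list(strides) == [S * H * D, D, H * D, 1]
        return row_contiguous or seq_head_transposed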
@@ -328,18 +349,30 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   // We are in vector mode ie single query
-  if (q_pre.shape(2) == 1) {
-    const auto& q = copy_unless(is_contiguous, q_pre);
-    // 1, heads, seq_len, head_dim
-    // mask [1, query_heads, 1, seq_len]
+  if (q_pre.shape(2) <= 8) {
+    const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
     const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
     const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
 
     // Donate the query if possible
-    if (q.is_donatable() && q.size() == o.size()) {
+    if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
+        q.size() == o.size()) {
       o.move_shared_buffer(q);
     } else {
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      if (o.shape(2) == 1) {
+        o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      } else {
+        auto strides = o.strides();
+        strides[2] = o.shape(1) * o.shape(3);
+        strides[1] = o.shape(3);
+        auto flags = q.flags();
+        flags.row_contiguous = q.shape(1) == 1;
+        o.set_data(
+            allocator::malloc_or_wait(o.nbytes()),
+            o.size(),
+            std::move(strides),
+            flags);
+      }
     }
 
     auto mask =
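Note: when more than one query position is produced and the buffer cannot be donated, the output keeps its logical (B, H, S, D) shape but is allocated with seq/head-transposed strides, so the kernel's (query position, head)-ordered writes land contiguously. The strides it constructs, mirrored in Python:

    def transposed_out_strides(B, H, S, D):
        # Data stored as (B, S, H, D); the logical shape stays (B, H, S, D).
        return [S * H * D, D, H * D, 1]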
@@ -715,7 +715,8 @@ array scaled_dot_product_attention(
   const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
       sdpa_full_supported_head_dim && stream.device == Device::gpu;
 
-  const bool supports_sdpa_vector = query_sequence_length == 1 &&
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= k.shape(-2)) &&
       (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
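Note: routing now sends up to eight query positions to the vector kernels, with the added guard that the query length not exceed the KV length; longer query blocks continue to use the full attention kernel. The condition, paraphrased as a Python predicate (not a new API):

    def use_sdpa_vector(q_len, kv_len, mask_is_bool_or_absent, head_dim_ok, on_gpu):
        return (q_len <= 8 and q_len <= kv_len
                and mask_is_bool_or_absent and head_dim_ok and on_gpu)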
@@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_fast_sdpa_few_query(self):
+        D = 64
+        L = 43
+        Lq = 4
+        Nq = 8
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
+        q = q.swapaxes(1, 2)
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        return
+        L = 4096
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     @unittest.skip("Different head and value dims is not enabled")
     def test_fast_sdpa_vector_value_dims(self):
         D = 192
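Note: the new test drives the vector path with four query positions in the seq/head-transposed layout (built with swapaxes); the row-contiguous block against a 4096-token KV cache after the bare return is currently skipped. For reference, a minimal end-to-end call exercising the new path (shapes mirror the test; the scale value is illustrative):

    import mlx.core as mx

    D, L, Lq, Nq, Nkv = 64, 43, 4, 8, 1
    q = mx.random.normal(shape=(1, Lq, Nq, D)).swapaxes(1, 2)  # (1, Nq, Lq, D), transposed layout
    k = mx.random.normal(shape=(1, Nkv, L, D))
    v = mx.random.normal(shape=(1, Nkv, L, D))
    mask = mx.random.uniform(shape=(Nq, 1, L)) > 0.2  # boolean, broadcast over query positions
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)
    print(out.shape)  # (1, 8, 4, 64)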