Add boolean mask support in vector SDPA (#1757)

parent 516ded618b
commit d1766f2c70
@@ -12,7 +12,7 @@ dtype = mx.float16
 loops = 10


-def attention(q, k, v):
+def attention(q, k, v, mask=None):
     def _sdpa(q, k, v):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
@@ -20,6 +20,9 @@ def attention(q, k, v):
         k = k[:, :, None, :, :]
         v = v[:, :, None, :, :]
         s = q @ k.transpose(0, 1, 2, 4, 3)
+        if mask is not None:
+            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
+            s = mx.where(m, s, mx.finfo(s.dtype).min)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
         o = p @ v
         return o.reshape(B, Hq, L, D)
@@ -29,9 +32,9 @@ def attention(q, k, v):
     return q


-def sdpa(q, k, v):
+def sdpa(q, k, v, mask=None):
    for i in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
     return q

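The hunks above are from the Python benchmark script: the reference attention now takes an optional boolean mask, broadcasts it to the score shape, and fills masked positions before the softmax. A standalone sketch of that masking with GQA-style head grouping (all sizes here are illustrative, not the benchmark's):

import mlx.core as mx

# Illustrative sizes (assumptions): 32 query heads grouped over 8 KV heads.
B, Hq, Hk, L, S, D = 1, 32, 8, 1, 512, 64

q = mx.random.normal(shape=(B, Hq, L, D)).reshape(B, Hk, Hq // Hk, L, D)
k = mx.random.normal(shape=(B, Hk, S, D))[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)  # scores: (B, Hk, Hq // Hk, L, S)

# Boolean mask over keys (True = attend), broadcastable over batch/heads/query.
mask = mx.random.uniform(shape=(S,)) > 0.2
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)

# Masked scores get the most negative finite value, so softmax drives them to ~0.
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)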
@@ -53,6 +56,26 @@ def time_self_attention_sdpa():
     time_fn(sdpa, q, k, v)


+def time_self_attention_sdpa_with_mask():
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    mask = mx.full((L,), True)
+    mask[L // 2 :] = False
+    mx.eval(q, k, v, mask)
+
+    def sdpa_mask(*args):
+        return sdpa(*args, mask=mask)
+
+    def attention_mask(*args):
+        return attention(*args, mask=mask)
+
+    time_fn(attention_mask, q, k, v)
+    time_fn(sdpa_mask, q, k, v)
+
+
 if __name__ == "__main__":
     time_self_attention_sdpa()
     time_self_attention_primitives()
+    time_self_attention_sdpa_with_mask()
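This hunk adds a masked variant of the benchmark (its time_fn helper lives elsewhere in the file and is not shown here). A self-contained timing sketch along the same lines, with assumed sizes for H, H_k, L, and D:

import time
import mlx.core as mx

H, H_k, L, D = 32, 32, 4096, 64  # assumed benchmark-like sizes
dtype = mx.float16

q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mask = mx.full((L,), True)
mask[L // 2 :] = False  # attend only to the first half of the keys
mx.eval(q, k, v, mask)

def timeit(fn, *args, iters=50):
    mx.eval(fn(*args))  # warmup; also triggers kernel compilation
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(*args))
    return (time.perf_counter() - tic) / iters

t_plain = timeit(lambda a, b, c: mx.fast.scaled_dot_product_attention(a, b, c, scale=1.0), q, k, v)
t_masked = timeit(lambda a, b, c: mx.fast.scaled_dot_product_attention(a, b, c, scale=1.0, mask=mask), q, k, v)
print(f"no mask: {t_plain * 1e3:.3f} ms, bool mask: {t_masked * 1e3:.3f} ms")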
@@ -4,6 +4,8 @@

 using namespace metal;

+constant bool has_mask [[function_constant(20)]];
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
@@ -15,6 +17,9 @@ template <typename T, int D>
     const constant size_t& k_stride,
     const constant size_t& v_stride,
     const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -39,6 +44,9 @@ template <typename T, int D>
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
   values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+  }
   out += head_idx * D + simd_gid * elem_per_thread;

   // Read the query and 0 the output accumulator
@@ -54,34 +62,39 @@ template <typename T, int D>

   // For each key
   for (int i = simd_gid; i < N; i += BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }

-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);

-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);

-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;

-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
-    }
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
+    }

     // Move the pointers to the next kv
     keys += stride;
     values += stride;
+    if (has_mask) {
+      mask += BN * mask_seq_stride;
+    }
   }

   // Each thread has a partial part of the output so we need to combine them.
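This is the heart of the change in the single-query (vector) Metal kernel: the loop is a streaming "online" softmax, and with a boolean mask a masked key is skipped entirely, leaving the running max, exp-sum, and output accumulators untouched. A scalar Python sketch of the same recurrence (assumes at least one unmasked key):

import math

def online_softmax_attention(scores, values, mask=None):
    # Streaming softmax-weighted sum, mirroring the kernel's accumulators.
    max_score = -math.inf
    sum_exp = 0.0
    out = 0.0
    for i, (s, v) in enumerate(zip(scores, values)):
        if mask is not None and not mask[i]:
            continue  # masked key: skip, accumulators unchanged
        new_max = max(max_score, s)
        factor = math.exp(max_score - new_max)  # rescale previous state
        exp_s = math.exp(s - new_max)
        max_score = new_max
        sum_exp = sum_exp * factor + exp_s
        out = out * factor + exp_s * v
    return out / sum_exp

# Matches a direct softmax over only the unmasked positions:
scores, values, mask = [0.3, 2.0, -1.0], [1.0, 2.0, 3.0], [True, False, True]
print(online_softmax_attention(scores, values, mask))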
@@ -126,6 +139,9 @@ template <typename T, int D>
     const constant size_t& k_stride,
     const constant size_t& v_stride,
     const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -155,6 +171,10 @@ template <typename T, int D>
   values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
       simd_lid * elem_per_thread;
   out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_seq_stride;
+  }
   sums += head_idx * blocks + block_idx;
   maxs += head_idx * blocks + block_idx;

@@ -171,34 +191,39 @@ template <typename T, int D>

   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int i = 0; i < elem_per_thread; i++) {
+        k[i] = keys[i];
+      }

-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int i = 0; i < elem_per_thread; i++) {
+        score += q[i] * k[i];
+      }
+      score = simd_sum(score);

-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);

-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;

-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
-    }
+      // Update the output accumulator
+      for (int i = 0; i < elem_per_thread; i++) {
+        o[i] = o[i] * factor + exp_score * values[i];
+      }
+    }

     // Move the pointers to the next kv
     keys += blocks * stride;
     values += blocks * stride;
+    if (has_mask) {
+      mask += BN * blocks * mask_seq_stride;
+    }
   }

   // Each thread has a partial part of the output so we need to combine them.
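The 2-pass variant gets the same treatment: keys are strided across blocks threadgroups, each writing its running max, exp-sum, and unnormalized partial output. The combine step itself is outside this hunk; a Python sketch of how such per-block partials merge, assuming that layout:

import math

def combine_blocks(maxs, sums, outs):
    # maxs[b], sums[b]: block b's running max and exp-sum;
    # outs[b]: its unnormalized partial output.
    g_max = max(maxs)
    g_sum = 0.0
    g_out = 0.0
    for m, s, o in zip(maxs, sums, outs):
        w = math.exp(m - g_max)  # rescale each block to the global max
        g_sum += s * w
        g_out += o * w
    return g_out / g_sum  # normalize once at the end

Because a fully masked block contributes sum_exp = 0 and a zero partial output, skipped keys drop out of the combine automatically.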
@@ -1,6 +1,5 @@
 // Copyright © 2024 Apple Inc.

-#include <cassert>
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
@@ -116,7 +115,8 @@ void sdpa_vector(
     const array& k,
     const array& v,
     array& out,
-    float scale) {
+    float scale,
+    const std::optional<array>& mask) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -134,9 +134,16 @@ void sdpa_vector(
   MTL::Size group_dims(1024, 1, 1);
   MTL::Size grid_dims(1, B, 1);

+  bool has_mask = mask.has_value();
+  metal::MTLFCList func_consts = {
+      {&has_mask, MTL::DataType::DataTypeBool, 20},
+  };
+  std::string hash_name = kname;
+  hash_name += has_mask ? "_mask" : "_nomask";
+
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname);
+  auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Set its arguments
@@ -149,6 +156,14 @@ void sdpa_vector(
   compute_encoder.set_bytes(k_stride, 6);
   compute_encoder.set_bytes(v_stride, 7);
   compute_encoder.set_bytes(scale, 8);
+  if (has_mask) {
+    auto& m = *mask;
+    compute_encoder.set_input_array(m, 9);
+    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
+    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
+    compute_encoder.set_bytes(seq_stride, 10);
+    compute_encoder.set_bytes(head_stride, 11);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
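On the host side the mask is bound as a plain buffer plus two element strides, with absent axes treated as broadcast (stride 0). A numpy sketch of the stride derivation (numpy strides are in bytes, hence the division by itemsize; the C++ code works in elements):

import numpy as np

def mask_strides(m: np.ndarray):
    # Mirrors the host logic: the last axis walks the key sequence; the axis
    # three from the end (if present) walks query heads. Axes the mask does
    # not have are broadcast, i.e. stride 0.
    elem_strides = [s // m.itemsize for s in m.strides]
    seq_stride = elem_strides[-1] if m.ndim >= 1 else 0
    head_stride = elem_strides[-3] if m.ndim >= 3 else 0
    return seq_stride, head_stride

print(mask_strides(np.ones((1, 8, 1, 512), dtype=bool)))  # (1, 512)
print(mask_strides(np.ones((512,), dtype=bool)))          # (1, 0): same row for every head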
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
     const array& k,
     const array& v,
     array& out,
-    float scale) {
+    float scale,
+    const std::optional<array>& mask) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
   d.add_temporary(sums, s.index);
   d.add_temporary(maxs, s.index);

+  bool has_mask = mask.has_value();
+  metal::MTLFCList func_consts = {
+      {&has_mask, MTL::DataType::DataTypeBool, 20},
+  };
+  std::string hash_name = kname;
+  hash_name += has_mask ? "_mask" : "_nomask";
+
   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname);
+  auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
+
   compute_encoder.set_compute_pipeline_state(kernel);

   // Set its arguments
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
   compute_encoder.set_bytes(k_stride, 8);
   compute_encoder.set_bytes(v_stride, 9);
   compute_encoder.set_bytes(scale, 10);
+  if (has_mask) {
+    auto& m = *mask;
+    compute_encoder.set_input_array(m, 11);
+    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
+    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
+    compute_encoder.set_bytes(seq_stride, 12);
+    compute_encoder.set_bytes(head_stride, 13);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
-  assert(inputs.size() == 3);
-
   auto& s = stream();
   auto& d = metal::device(s.device);

@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
     const auto& q = copy_unless(is_contiguous, q_pre);
+    // 1, heads, seq_len, head_dim
+    // mask [1, query_heads, 1, seq_len]
     const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
     const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);

@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }

+    auto mask =
+        inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
+
     // We route to the 2 pass fused attention if
     // - The device is large and the sequence length long
     // - The sequence length is even longer and we have gqa
     char devc = d.get_architecture().back();
     if ((devc == 'd' && k.shape(2) >= 1024) ||
         (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
     } else {
-      sdpa_vector(s, d, q, k, v, o, scale_);
+      sdpa_vector(s, d, q, k, v, o, scale_, mask);
     }
   }

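eval_gpu now picks the mask up as an optional fourth primitive input and threads it into both kernels. The one-pass/two-pass routing itself is unchanged; a Python paraphrase of the condition shown above, with hypothetical parameter names:

def use_two_pass(arch: str, n_q_heads: int, n_kv_heads: int, kv_len: int) -> bool:
    # Paraphrase of the C++ routing: large ("...d") devices with long
    # sequences, or GQA (fewer KV heads than query heads) with very long ones.
    devc = arch[-1]
    return (devc == "d" and kv_len >= 1024) or (
        n_kv_heads < n_q_heads and kv_len >= 4096
    )

print(use_two_pass("applegpu_g13d", 32, 8, 2048))  # True (made-up architecture string)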
mlx/fast.cpp (51 changed lines)
@@ -609,27 +609,32 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }

-  if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
-    std::ostringstream msg;
-    msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-        << final_type << ".";
-    throw std::invalid_argument(msg.str());
+  if (mask) {
+    // Check type
+    if (promote_types(mask->dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Check shape
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+          << " does not broadcast to implicit scores with shape " << mask_shape
+          << ".";
+      throw std::invalid_argument(msg.str());
+    }
   }

   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);

-  /* generic implementation for use cases that Metal implementation does not
-   * support. For non-supported cases listed below, use MLX primitives:
-   * * CPU implementation
-   * * batch size > 1 for decoding or causal attention
-   * * query sequence length > 1 for decoding
-   * * query sequence length > 16 && non-null mask (causal attention)
-   * * non-null mask
-   * * dtype is not fp32 or fp16
-   */
-
+  /* Generic implementation for use cases that Metal implementation does not
+   * support. */
   int threshold = 32; // TODO: Fix after dev
   if (memory_efficient_threshold.has_value()) {
     threshold = std::max(1, memory_efficient_threshold.value());
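Besides the existing dtype check, the mask must now broadcast to the implicit score shape: the query shape with its last axis replaced by the key sequence length. An equivalent check sketched with numpy's broadcasting rules (hypothetical helper, not the library API):

import numpy as np

def check_mask_shape(q_shape, k_shape, mask_shape):
    # Scores are implicitly (B, Hq, Lq, Lk): the query shape with the last
    # axis replaced by the key sequence length.
    score_shape = list(q_shape)
    score_shape[-1] = k_shape[-2]
    score_shape = tuple(score_shape)
    try:
        ok = np.broadcast_shapes(mask_shape, score_shape) == score_shape
    except ValueError:
        ok = False
    if not ok:
        raise ValueError(
            f"Mask with shape {mask_shape} does not broadcast "
            f"to implicit scores with shape {score_shape}."
        )

check_mask_shape((1, 32, 1, 4096), (1, 8, 4096, 64), (4096,))        # fine
# check_mask_shape((1, 32, 1, 4096), (1, 8, 4096, 64), (4, 2, 4096))  # raises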
@@ -690,27 +695,27 @@ array scaled_dot_product_attention(
       !mask.has_value() && sdpa_full_supported_head_dim &&
       stream.device == Device::gpu;

+  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      !mask.has_value() && sdpa_vector_supported_head_dim &&
+      supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;

   implementation_supports_use_case &=
       supports_sdpa_full || supports_sdpa_vector;

+  std::vector<array> inputs = {q, k, v};
+  if (mask) {
+    inputs.push_back(*mask);
+  }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
-        {q, k, v});
-  }
-
-  if (mask.has_value()) {
-    return fallback({q, k, v, mask.value()})[0];
-  } else {
-    return fallback({q, k, v})[0];
+        std::move(inputs));
   }
+  return fallback(inputs)[0];
 }

 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
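With the mask forwarded as an extra primitive input, the single-query vector path now accepts boolean masks end to end. A minimal usage example (sizes are illustrative):

import mlx.core as mx

B, Hq, Hkv, L, D = 1, 32, 8, 1024, 64
q = mx.random.normal(shape=(B, Hq, 1, D))
k = mx.random.normal(shape=(B, Hkv, L, D))
v = mx.random.normal(shape=(B, Hkv, L, D))

# Attend only to the first half of the KV cache.
mask = mx.full((L,), True)
mask[L // 2 :] = False

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)
print(out.shape)  # (1, 32, 1, 64)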
@@ -10,7 +10,10 @@ import numpy as np
 def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
     if mask is not None:
-        p += mask
+        if mask.dtype == mx.bool_:
+            p = mx.where(mask, p, mx.finfo(mx.float32).min)
+        else:
+            p += mask
     scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
     return scores @ v

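The test reference now branches on mask dtype: boolean masks select scores (masked positions get the most negative finite value), while other dtypes stay additive. The two conventions are interchangeable; a sketch converting one to the other:

import mlx.core as mx

def bool_to_additive(mask, dtype=mx.float32):
    # True -> 0 (keep), False -> most negative finite value (suppressed by softmax).
    return mx.where(mask, 0.0, mx.finfo(dtype).min).astype(dtype)

m_bool = mx.array([True, False, True])
m_add = bool_to_additive(m_bool)  # scores + m_add ~= mx.where(m_bool, scores, min)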
@@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))

+    def test_fast_sdpa_vector(self):
+        D = 64
+        L = 43
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=mx.full((Nq, 2, L), False),
+            )
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        L = 4096
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+

 if __name__ == "__main__":
     unittest.main(failfast=True)