mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa * Add fallback * Update steel::AttnParams * Fix typo * WIP, basic causal * Update tests * Update benchmarking * Update masking loop limits * Add bool masking and update tests * Update additive mask * Update benchmarks * Update benchmarks * Update tests * Update for bfloat error * Update early exit * Add random seed to tests
This commit is contained in:
parent
3c164fca8c
commit
9adcd1a650
@ -28,11 +28,34 @@ def bench(f, *args):
|
|||||||
return (e - s) * 1e-9
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||||
|
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
|
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||||
|
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
|
q_mx = mx.array(q_np)
|
||||||
|
k_mx = mx.array(k_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask == "additive":
|
||||||
|
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
elif mask == "bool":
|
||||||
|
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
|
||||||
|
return q_mx, k_mx, v_mx, scale, mask
|
||||||
|
|
||||||
|
|
||||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||||
q_dtype = q.dtype
|
q_dtype = q.dtype
|
||||||
q = q * mx.array(scale, q_dtype)
|
q = q * mx.array(scale, q_dtype)
|
||||||
n_q_heads = q.shape[-3]
|
n_q_heads = q.shape[-3]
|
||||||
@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
|
|
||||||
B = q.shape[0]
|
B = q.shape[0]
|
||||||
L = q.shape[2]
|
L = q.shape[2]
|
||||||
|
kL = k.shape[2]
|
||||||
|
|
||||||
if n_repeats > 1:
|
if n_repeats > 1:
|
||||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||||
@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
v = mx.expand_dims(v, 2)
|
v = mx.expand_dims(v, 2)
|
||||||
|
|
||||||
scores = q @ mx.swapaxes(k, -1, -2)
|
scores = q @ mx.swapaxes(k, -1, -2)
|
||||||
if f32softmax:
|
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
if mask is not None:
|
||||||
|
|
||||||
|
if mask == "causal":
|
||||||
|
q_offset = max(0, kL - L)
|
||||||
|
q_indices = mx.arange(q_offset, q_offset + L)
|
||||||
|
k_indices = mx.arange(kL)
|
||||||
|
mask = q_indices[:, None] >= k_indices[None]
|
||||||
|
|
||||||
|
if n_repeats > 1 and mask.ndim >= 3:
|
||||||
|
if mask.shape[-3] == 1:
|
||||||
|
mask = mx.expand_dims(mask, -3)
|
||||||
else:
|
else:
|
||||||
scores = mx.softmax(scores, axis=-1)
|
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||||
|
|
||||||
|
if mask.dtype == mx.bool_:
|
||||||
|
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||||
|
else:
|
||||||
|
scores += mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
|
||||||
out = scores @ v
|
out = scores @ v
|
||||||
if n_repeats > 1:
|
if n_repeats > 1:
|
||||||
@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
def mlx_fused_attn(q, k, v, scale, mask):
|
||||||
q_out = q
|
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
if transpose:
|
if transpose:
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||||
|
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||||
|
else:
|
||||||
|
return f(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
q_out = q
|
||||||
|
|
||||||
for i in range(N_iter_func):
|
for i in range(N_iter_func):
|
||||||
if transpose:
|
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
mx.eval(q_out)
|
mx.eval(q_out)
|
||||||
return q_out
|
return q_out
|
||||||
|
|
||||||
|
|
||||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
def bench_shape(
|
||||||
q_out = q
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||||
if transpose:
|
):
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||||
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
mx.eval(q_out)
|
|
||||||
return q_out
|
|
||||||
|
|
||||||
|
|
||||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
|
||||||
shape_q = (
|
|
||||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
|
||||||
)
|
|
||||||
shape_kv = (
|
|
||||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
time_mlx_unfused = bench(
|
||||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
)
|
||||||
|
time_mlx_fused = bench(
|
||||||
|
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
scale = math.sqrt(1.0 / head_dim)
|
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||||
|
o_mlx_unfused = do_attention(
|
||||||
|
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
q_mx = mx.array(q_np)
|
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||||
k_mx = mx.array(k_np)
|
|
||||||
v_mx = mx.array(v_np)
|
|
||||||
|
|
||||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
|
||||||
|
|
||||||
if transpose:
|
|
||||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
|
||||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
|
||||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
|
||||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
|
||||||
|
|
||||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
|
||||||
|
|
||||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
|
||||||
print(
|
print(
|
||||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return time_mlx_fused, time_mlx_unfused
|
return time_mlx_fused, time_mlx_unfused
|
||||||
@ -151,39 +173,51 @@ if __name__ == "__main__":
|
|||||||
( 1, 128, 128, 64, 32, 32),
|
( 1, 128, 128, 64, 32, 32),
|
||||||
( 1, 256, 256, 64, 32, 32),
|
( 1, 256, 256, 64, 32, 32),
|
||||||
( 1, 512, 512, 64, 32, 32),
|
( 1, 512, 512, 64, 32, 32),
|
||||||
( 1, 1024, 1024, 64, 32, 32),
|
( 1, 1024, 1024, 64, 32, 8),
|
||||||
( 1, 2048, 2048, 64, 32, 32),
|
( 1, 2048, 2048, 64, 32, 8),
|
||||||
( 1, 4096, 4096, 64, 32, 32),
|
( 1, 4096, 4096, 64, 32, 8),
|
||||||
)
|
)
|
||||||
|
|
||||||
shapes_80 = (
|
shapes_80 = (
|
||||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
( 1, 1024, 1024, 80, 32, 32),
|
( 1, 1024, 1024, 80, 32, 8),
|
||||||
( 1, 2048, 2048, 80, 32, 32),
|
( 1, 2048, 2048, 80, 32, 8),
|
||||||
( 1, 4096, 4096, 80, 32, 32),
|
( 1, 4096, 4096, 80, 32, 8),
|
||||||
)
|
)
|
||||||
|
|
||||||
shapes_128 = (
|
shapes_128 = (
|
||||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
( 1, 1024, 1024, 128, 32, 32),
|
( 1, 1024, 1024, 128, 32, 8),
|
||||||
( 1, 2048, 2048, 128, 32, 32),
|
( 1, 2048, 2048, 128, 32, 8),
|
||||||
( 1, 4096, 4096, 128, 32, 32),
|
( 1, 4096, 4096, 128, 32, 8),
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
shapes = shapes_64 + shapes_80 + shapes_128
|
shapes = shapes_64 + shapes_80 + shapes_128
|
||||||
|
|
||||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
masks = [None, "bool", "causal"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||||
|
)
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
for transpose in transposes:
|
for transpose in transposes:
|
||||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||||
np_dtype = getattr(np, dtype)
|
for mask_in in masks:
|
||||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
B,
|
||||||
|
qsl,
|
||||||
|
ksl,
|
||||||
|
head_dim,
|
||||||
|
n_q_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
dtype,
|
||||||
|
transpose,
|
||||||
|
mask_in,
|
||||||
)
|
)
|
||||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||||
t_str = 1 if transpose else 0
|
t_str = 1 if transpose else 0
|
||||||
print(
|
print(
|
||||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024-25 Apple Inc.
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
@ -9,6 +9,9 @@ using namespace mlx::steel;
|
|||||||
constant bool align_Q [[function_constant(200)]];
|
constant bool align_Q [[function_constant(200)]];
|
||||||
constant bool align_K [[function_constant(201)]];
|
constant bool align_K [[function_constant(201)]];
|
||||||
|
|
||||||
|
constant bool has_mask [[function_constant(300)]];
|
||||||
|
constant bool do_causal [[function_constant(301)]];
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TransformScale {
|
struct TransformScale {
|
||||||
T scale;
|
T scale;
|
||||||
@ -69,6 +72,7 @@ template <
|
|||||||
int BD,
|
int BD,
|
||||||
int WM,
|
int WM,
|
||||||
int WN,
|
int WN,
|
||||||
|
typename MaskType = float,
|
||||||
typename AccumType = float>
|
typename AccumType = float>
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
|
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
|
||||||
const device T* Q [[buffer(0)]],
|
const device T* Q [[buffer(0)]],
|
||||||
@ -76,6 +80,8 @@ template <
|
|||||||
const device T* V [[buffer(2)]],
|
const device T* V [[buffer(2)]],
|
||||||
device T* O [[buffer(3)]],
|
device T* O [[buffer(3)]],
|
||||||
const constant AttnParams* params [[buffer(4)]],
|
const constant AttnParams* params [[buffer(4)]],
|
||||||
|
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
|
||||||
|
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -102,6 +108,11 @@ template <
|
|||||||
tidl.y * params->O_strides[1] + // Head
|
tidl.y * params->O_strides[1] + // Head
|
||||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
mask += tidl.z * mask_params->M_strides[0] + // Batch
|
||||||
|
tidl.y * mask_params->M_strides[1]; // Head
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare threadgroup memory
|
// Prepare threadgroup memory
|
||||||
constexpr short padQ = 16 / sizeof(T);
|
constexpr short padQ = 16 / sizeof(T);
|
||||||
constexpr short padK = 16 / sizeof(T);
|
constexpr short padK = 16 / sizeof(T);
|
||||||
@ -203,7 +214,7 @@ template <
|
|||||||
|
|
||||||
// Load Q blocks apply scale
|
// Load Q blocks apply scale
|
||||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||||
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
|
loader_q.load_safe(short2(BD, params->qL_rem));
|
||||||
} else {
|
} else {
|
||||||
loader_q.load_unsafe();
|
loader_q.load_unsafe();
|
||||||
}
|
}
|
||||||
@ -221,12 +232,19 @@ template <
|
|||||||
max_score[i] = Limits<AccumType>::min;
|
max_score[i] = Limits<AccumType>::min;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int kb_lim = params->NK;
|
||||||
|
|
||||||
|
if (do_causal) {
|
||||||
|
int q_max = (tid.x + 1) * BQ + params->qL_off;
|
||||||
|
kb_lim = (q_max + BK - 1) / BK;
|
||||||
|
}
|
||||||
|
|
||||||
// Loop over KV seq length
|
// Loop over KV seq length
|
||||||
for (int kb = 0; kb < params->NK; kb++) {
|
for (int kb = 0; kb < kb_lim; kb++) {
|
||||||
// Load K block and apply scale
|
// Load K block and apply scale
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (!align_K && kb == (params->NK_aligned)) {
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
loader_k.load_safe(short2(BD, params->kL_rem));
|
||||||
} else {
|
} else {
|
||||||
loader_k.load_unsafe();
|
loader_k.load_unsafe();
|
||||||
}
|
}
|
||||||
@ -250,12 +268,11 @@ template <
|
|||||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mask out of length sequence
|
// Mask out length sequence
|
||||||
if (!align_K && kb == (params->NK_aligned)) {
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
using stile_t = decltype(Stile);
|
using stile_t = decltype(Stile);
|
||||||
using selem_t = typename stile_t::elem_type;
|
using selem_t = typename stile_t::elem_type;
|
||||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||||
const short lim = params->kL - params->NK_aligned * BK;
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
@ -264,7 +281,7 @@ template <
|
|||||||
short col_pos = sn + (j * stile_t::kFragCols);
|
short col_pos = sn + (j * stile_t::kFragCols);
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||||
if ((col_pos + jj) >= lim) {
|
if ((col_pos + jj) >= params->kL_rem) {
|
||||||
Stile.frag_at(i, j)[jj] = neg_inf;
|
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -272,11 +289,78 @@ template <
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mask out if causal
|
||||||
|
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
|
||||||
|
using stile_t = decltype(Stile);
|
||||||
|
using selem_t = typename stile_t::elem_type;
|
||||||
|
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
|
const int row_pos =
|
||||||
|
tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||||
|
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||||
|
if (row_pos < (col_pos + jj)) {
|
||||||
|
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other masking as needed
|
||||||
|
if (has_mask) {
|
||||||
|
using stile_t = decltype(Stile);
|
||||||
|
using selem_t = typename stile_t::elem_type;
|
||||||
|
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||||
|
|
||||||
|
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
||||||
|
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
|
||||||
|
|
||||||
|
using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
|
||||||
|
using frag_t = typename MMAFrag_mask_t::frag_type;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
|
const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||||
|
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
|
||||||
|
|
||||||
|
frag_t mfrag;
|
||||||
|
|
||||||
|
MMAFrag_mask_t::load_safe(
|
||||||
|
mfrag,
|
||||||
|
mask,
|
||||||
|
int(mask_params->M_strides[2]),
|
||||||
|
Int<1>{},
|
||||||
|
params->qL,
|
||||||
|
params->kL,
|
||||||
|
row_pos,
|
||||||
|
col_pos);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
|
||||||
|
if constexpr (is_bool) {
|
||||||
|
Stile.frag_at(i, j)[jj] =
|
||||||
|
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
|
||||||
|
} else {
|
||||||
|
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Load V blocks
|
// Load V blocks
|
||||||
if (!align_K && kb == (params->NK_aligned)) {
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
loader_v.load_safe(short2(BD, params->kL_rem));
|
||||||
} else {
|
} else {
|
||||||
loader_v.load_unsafe();
|
loader_v.load_unsafe();
|
||||||
}
|
}
|
||||||
@ -367,8 +451,7 @@ template <
|
|||||||
O += (tm + sm) * params->O_strides[2] + sn;
|
O += (tm + sm) * params->O_strides[2] + sn;
|
||||||
|
|
||||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||||
auto dst_tile_dims =
|
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
|
||||||
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
|
|
||||||
|
|
||||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||||
return;
|
return;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024-25 Apple Inc.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
@ -6,26 +6,23 @@
|
|||||||
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
|
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
|
||||||
|
|
||||||
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
|
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
|
||||||
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
|
instantiate_kernel( \
|
||||||
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
|
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||||
const device dtype* Q [[buffer(0)]], \
|
"_wm" #wm "_wn" #wn "_mask" #mname, \
|
||||||
const device dtype* K [[buffer(1)]], \
|
attention, dtype, bq, bk, bd, wm, wn, mtype, float)
|
||||||
const device dtype* V [[buffer(2)]], \
|
|
||||||
device dtype* O [[buffer(3)]],\
|
|
||||||
const constant AttnParams* params [[buffer(4)]], \
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
|
||||||
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
|
instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
|
||||||
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
|
||||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
|
instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
|
||||||
|
|
||||||
instantiate_attn_shapes_helper(float16, half);
|
#define instantiate_attn_mask_helper(iname, itype) \
|
||||||
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
|
||||||
|
instantiate_attn_shapes_helper(iname, itype, bool_, bool)
|
||||||
|
|
||||||
instantiate_attn_shapes_helper(float32, float);
|
instantiate_attn_mask_helper(float16, half);
|
||||||
|
instantiate_attn_mask_helper(bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
instantiate_attn_mask_helper(float32, float);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
@ -111,7 +111,7 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
for (short j = 0; j < kElemCols; j++) {
|
for (short j = 0; j < kElemCols; j++) {
|
||||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||||
dst[i * kElemCols + j] =
|
dst[i * kElemCols + j] =
|
||||||
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
|
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
|
||||||
} else {
|
} else {
|
||||||
dst[i * kElemCols + j] = T(0);
|
dst[i * kElemCols + j] = T(0);
|
||||||
}
|
}
|
||||||
|
@ -26,11 +26,19 @@ struct AttnParams {
|
|||||||
int NQ_aligned; ///< Number of full query blocks
|
int NQ_aligned; ///< Number of full query blocks
|
||||||
int NK_aligned; ///< Number of full key/value blocks
|
int NK_aligned; ///< Number of full key/value blocks
|
||||||
|
|
||||||
|
int qL_rem; ///< Remainder in last query block
|
||||||
|
int kL_rem; ///< Remainder in last key/value block
|
||||||
|
int qL_off; ///< Offset in query sequence start
|
||||||
|
|
||||||
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
||||||
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
||||||
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
||||||
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
|
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct AttnMaskParams {
|
||||||
|
int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace steel
|
} // namespace steel
|
||||||
} // namespace mlx
|
} // namespace mlx
|
||||||
|
@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
|
|||||||
const array& k,
|
const array& k,
|
||||||
const array& v,
|
const array& v,
|
||||||
const float scale,
|
const float scale,
|
||||||
array& o) {
|
array& o,
|
||||||
|
bool do_causal_ = false,
|
||||||
|
const std::optional<array>& mask = std::nullopt) {
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
int wm = 4;
|
int wm = 4;
|
||||||
@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal(
|
|||||||
|
|
||||||
const bool align_Q = (qL % bq) == 0;
|
const bool align_Q = (qL % bq) == 0;
|
||||||
const bool align_K = (kL % bk) == 0;
|
const bool align_K = (kL % bk) == 0;
|
||||||
|
const bool has_mask = !!mask;
|
||||||
|
const bool do_causal = do_causal_;
|
||||||
|
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&align_Q, MTL::DataType::DataTypeBool, 200},
|
{&align_Q, MTL::DataType::DataTypeBool, 200},
|
||||||
{&align_K, MTL::DataType::DataTypeBool, 201},
|
{&align_K, MTL::DataType::DataTypeBool, 201},
|
||||||
};
|
{&has_mask, MTL::DataType::DataTypeBool, 300},
|
||||||
|
{&do_causal, MTL::DataType::DataTypeBool, 301}};
|
||||||
|
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
// clang-format off
|
// clang-format off
|
||||||
@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal(
|
|||||||
<< "_bq" << bq
|
<< "_bq" << bq
|
||||||
<< "_bk" << bk
|
<< "_bk" << bk
|
||||||
<< "_bd" << bd
|
<< "_bd" << bd
|
||||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
<< "_wm" << wm
|
||||||
|
<< "_wn" << wn
|
||||||
|
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
|
||||||
|
|
||||||
std::string base_name = kname.str();
|
std::string base_name = kname.str();
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kname << "_align_Q_" << (align_Q ? 't' : 'n')
|
kname << "_align_Q_" << (align_Q ? 't' : 'n')
|
||||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||||
|
<< "_has_mask_" << (has_mask ? 't' : 'n')
|
||||||
|
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
|
||||||
|
|
||||||
std::string hash_name = kname.str();
|
std::string hash_name = kname.str();
|
||||||
|
|
||||||
@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal(
|
|||||||
/* int NQ_aligned = */ NQ_aligned,
|
/* int NQ_aligned = */ NQ_aligned,
|
||||||
/* int NK_aligned = */ NK_aligned,
|
/* int NK_aligned = */ NK_aligned,
|
||||||
|
|
||||||
|
/* int qL_rem = */ (qL - NQ_aligned * bq),
|
||||||
|
/* int kL_rem = */ (kL - NK_aligned * bk),
|
||||||
|
/* int qL_off = */ (kL - qL),
|
||||||
|
|
||||||
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal(
|
|||||||
compute_encoder.set_output_array(o, 3);
|
compute_encoder.set_output_array(o, 3);
|
||||||
compute_encoder.set_bytes(params, 4);
|
compute_encoder.set_bytes(params, 4);
|
||||||
|
|
||||||
|
if (mask) {
|
||||||
|
auto m = *mask;
|
||||||
|
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
|
||||||
|
m.strides(0), m.strides(1), m.strides(2)}};
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(mask_params, 5);
|
||||||
|
compute_encoder.set_input_array(m, 6);
|
||||||
|
}
|
||||||
|
|
||||||
MTL::Size grid_dims = MTL::Size(NQ, H, B);
|
MTL::Size grid_dims = MTL::Size(NQ, H, B);
|
||||||
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||||
|
|
||||||
@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
// Checks that the headdim dimension has stride 1.
|
// Checks that the headdim dimension has stride 1.
|
||||||
auto is_matrix_contiguous = [](const array& arr) {
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
return arr.strides(3) == 1;
|
return arr.strides(-1) == 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
{str_oB, str_oH, str_oL, str_oD},
|
{str_oB, str_oH, str_oL, str_oD},
|
||||||
flags);
|
flags);
|
||||||
|
|
||||||
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
|
auto mask = inputs.size() > 3
|
||||||
|
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
|
||||||
|
: std::nullopt;
|
||||||
|
|
||||||
|
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
|
83
mlx/fast.cpp
83
mlx/fast.cpp
@ -567,7 +567,7 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::optional<array>& mask,
|
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
|
||||||
const std::optional<int> memory_efficient_threshold,
|
const std::optional<int> memory_efficient_threshold,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
for (const auto& tensor : {queries, keys, values}) {
|
for (const auto& tensor : {queries, keys, values}) {
|
||||||
@ -578,10 +578,29 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (mask && (*mask).ndim() > 4) {
|
|
||||||
|
bool do_causal = false;
|
||||||
|
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||||
|
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
|
||||||
|
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
|
||||||
|
bool has_bool_mask = false;
|
||||||
|
|
||||||
|
if (has_str_mask) {
|
||||||
|
if (std::get<std::string>(mask) != "causal") {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||||
|
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
} else {
|
||||||
|
do_causal = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||||
<< (*mask).shape() << " expected to have at most rank 4";
|
<< (std::get<array>(mask)).shape()
|
||||||
|
<< " expected to have at most rank 4";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -631,9 +650,11 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mask) {
|
if (has_arr_mask) {
|
||||||
// Check type
|
// Check type
|
||||||
if (promote_types(mask->dtype(), final_type) != final_type) {
|
auto mask_arr = std::get<array>(mask);
|
||||||
|
has_bool_mask = mask_arr.dtype() == bool_;
|
||||||
|
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
|
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
|
||||||
<< final_type << ".";
|
<< final_type << ".";
|
||||||
@ -642,9 +663,10 @@ array scaled_dot_product_attention(
|
|||||||
// Check shape
|
// Check shape
|
||||||
auto mask_shape = queries.shape();
|
auto mask_shape = queries.shape();
|
||||||
mask_shape.back() = keys.shape(-2);
|
mask_shape.back() = keys.shape(-2);
|
||||||
if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
|
if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
|
msg << "[scaled_dot_product_attention] Mask with shape "
|
||||||
|
<< mask_arr.shape()
|
||||||
<< " does not broadcast to implicit scores with shape " << mask_shape
|
<< " does not broadcast to implicit scores with shape " << mask_shape
|
||||||
<< ".";
|
<< ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
@ -662,7 +684,7 @@ array scaled_dot_product_attention(
|
|||||||
threshold = std::max(1, memory_efficient_threshold.value());
|
threshold = std::max(1, memory_efficient_threshold.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
|
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
|
||||||
const std::vector<array>& inputs) {
|
const std::vector<array>& inputs) {
|
||||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||||
int n_repeats = n_q_heads / n_kv_heads;
|
int n_repeats = n_q_heads / n_kv_heads;
|
||||||
@ -676,9 +698,21 @@ array scaled_dot_product_attention(
|
|||||||
v = expand_dims(v, 2, s);
|
v = expand_dims(v, 2, s);
|
||||||
}
|
}
|
||||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||||
if (inputs.size() > 3) {
|
if (inputs.size() > 3 || do_causal) {
|
||||||
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
|
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
|
||||||
auto mask = inputs[3];
|
auto mask = inputs.back();
|
||||||
|
|
||||||
|
if (do_causal) {
|
||||||
|
int kL = k.shape(-2);
|
||||||
|
int qL = q.shape(-2);
|
||||||
|
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
|
||||||
|
auto q_idx = arange(q_off, q_off + qL, s);
|
||||||
|
auto k_idx = arange(0, kL, s);
|
||||||
|
q_idx = expand_dims(q_idx, 1, s);
|
||||||
|
k_idx = expand_dims(k_idx, 0, s);
|
||||||
|
mask = greater_equal(q_idx, k_idx, s);
|
||||||
|
}
|
||||||
|
|
||||||
if (n_repeats > 1 && mask.ndim() >= 3) {
|
if (n_repeats > 1 && mask.ndim() >= 3) {
|
||||||
if (mask.shape(-3) == 1) {
|
if (mask.shape(-3) == 1) {
|
||||||
mask = expand_dims(mask, -3, s);
|
mask = expand_dims(mask, -3, s);
|
||||||
@ -702,9 +736,10 @@ array scaled_dot_product_attention(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
const size_t value_head_dim = v.shape(-1);
|
const int value_head_dim = v.shape(-1);
|
||||||
const size_t query_head_dim = q.shape(-1);
|
const int query_head_dim = q.shape(-1);
|
||||||
const size_t query_sequence_length = q.shape(2);
|
const int query_sequence_length = q.shape(2);
|
||||||
|
const int key_sequence_length = k.shape(2);
|
||||||
|
|
||||||
const bool sdpa_vector_supported_head_dim =
|
const bool sdpa_vector_supported_head_dim =
|
||||||
query_head_dim == value_head_dim &&
|
query_head_dim == value_head_dim &&
|
||||||
@ -712,27 +747,33 @@ array scaled_dot_product_attention(
|
|||||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
|
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
||||||
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||||
|
(query_sequence_length <= key_sequence_length && do_causal);
|
||||||
|
|
||||||
|
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||||
|
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
||||||
|
stream.device == Device::gpu;
|
||||||
|
|
||||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||||
(query_sequence_length <= k.shape(-2)) &&
|
(query_sequence_length <= key_sequence_length) &&
|
||||||
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
|
sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
|
||||||
stream.device == Device::gpu;
|
stream.device == Device::gpu;
|
||||||
|
|
||||||
const bool implementation_supports_use_case =
|
const bool implementation_supports_use_case =
|
||||||
supports_sdpa_full || supports_sdpa_vector;
|
supports_sdpa_full || supports_sdpa_vector;
|
||||||
|
|
||||||
std::vector<array> inputs = {q, k, v};
|
std::vector<array> inputs = {q, k, v};
|
||||||
if (mask) {
|
if (has_arr_mask) {
|
||||||
inputs.push_back(*mask);
|
inputs.push_back(std::get<array>(mask));
|
||||||
}
|
}
|
||||||
if (implementation_supports_use_case) {
|
if (implementation_supports_use_case) {
|
||||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
final_type,
|
final_type,
|
||||||
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
|
std::make_shared<ScaledDotProductAttention>(
|
||||||
|
stream, fallback, scale, do_causal),
|
||||||
std::move(inputs));
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
return fallback(inputs)[0];
|
return fallback(inputs)[0];
|
||||||
@ -741,7 +782,7 @@ array scaled_dot_product_attention(
|
|||||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||||
const ScaledDotProductAttention& a_other =
|
const ScaledDotProductAttention& a_other =
|
||||||
static_cast<const ScaledDotProductAttention&>(other);
|
static_cast<const ScaledDotProductAttention&>(other);
|
||||||
return scale_ == a_other.scale_;
|
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
|
||||||
}
|
}
|
||||||
|
|
||||||
array pack_and_quantize(
|
array pack_and_quantize(
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -47,7 +48,7 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::optional<array>& mask = std::nullopt,
|
const std::variant<std::monostate, std::string, array>& mask = {},
|
||||||
const std::optional<int> memory_efficient_threshold = std::nullopt,
|
const std::optional<int> memory_efficient_threshold = std::nullopt,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -206,8 +206,9 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
explicit ScaledDotProductAttention(
|
explicit ScaledDotProductAttention(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||||
const float scale)
|
const float scale,
|
||||||
: Custom(stream, fallback), scale_(scale) {}
|
const bool do_causal)
|
||||||
|
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
@ -225,12 +226,13 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
DEFINE_PRINT(ScaledDotProductAttention);
|
DEFINE_PRINT(ScaledDotProductAttention);
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_pair(nullptr, scale_);
|
return std::make_tuple(nullptr, scale_, do_causal_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
float scale_;
|
float scale_;
|
||||||
|
bool do_causal_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AffineQuantize : public Custom {
|
class AffineQuantize : public Custom {
|
||||||
|
@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"memory_efficient_threshold"_a = nb::none(),
|
"memory_efficient_threshold"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||||
|
|
||||||
@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
||||||
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
||||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||||
mask (array, optional): A boolean or additive mask to apply to the
|
mask (Union[None, str, array], optional): A causal, boolean or additive
|
||||||
query-key scores. The mask can have at most 4 dimensions and must
|
mask to apply to the query-key scores. The mask can have at most 4
|
||||||
be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
|
dimensions and must be broadcast-compatible with the shape
|
||||||
additive mask is given its type must promote to the promoted
|
``[B, N, T_q, T_kv]``. If an additive mask is given its type must
|
||||||
type of ``q``, ``k``, and ``v``.
|
promote to the promoted type of ``q``, ``k``, and ``v``.
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array.
|
array: The output array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
@ -6,6 +6,91 @@ import mlx_tests
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||||
|
q_dtype = q.dtype
|
||||||
|
q = q * mx.array(scale, q_dtype)
|
||||||
|
n_q_heads = q.shape[-3]
|
||||||
|
n_kv_heads = k.shape[-3]
|
||||||
|
n_repeats = n_q_heads // n_kv_heads
|
||||||
|
|
||||||
|
B = q.shape[0]
|
||||||
|
L = q.shape[2]
|
||||||
|
kL = k.shape[2]
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||||
|
k = mx.expand_dims(k, 2)
|
||||||
|
v = mx.expand_dims(v, 2)
|
||||||
|
|
||||||
|
scores = q @ mx.swapaxes(k, -1, -2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
|
||||||
|
if mask == "causal":
|
||||||
|
q_offset = max(0, kL - L)
|
||||||
|
q_indices = mx.arange(q_offset, q_offset + L)
|
||||||
|
k_indices = mx.arange(kL)
|
||||||
|
mask = q_indices[:, None] >= k_indices[None]
|
||||||
|
|
||||||
|
if n_repeats > 1 and mask.ndim >= 3:
|
||||||
|
if mask.shape[-3] == 1:
|
||||||
|
mask = mx.expand_dims(mask, -3)
|
||||||
|
else:
|
||||||
|
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||||
|
|
||||||
|
if mask.dtype == mx.bool_:
|
||||||
|
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||||
|
else:
|
||||||
|
scores += mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
|
||||||
|
out = scores @ v
|
||||||
|
if n_repeats > 1:
|
||||||
|
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
if transpose:
|
||||||
|
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||||
|
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||||
|
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||||
|
else:
|
||||||
|
return f(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||||
|
np.random.seed(0)
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||||
|
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
|
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||||
|
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
|
q_mx = mx.array(q_np)
|
||||||
|
k_mx = mx.array(k_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask == "additive":
|
||||||
|
mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
elif mask == "bool":
|
||||||
|
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
|
||||||
|
return q_mx, k_mx, v_mx, scale, mask
|
||||||
|
|
||||||
|
|
||||||
# SDPA for MHA (n_heads == n_kv_heads)
|
# SDPA for MHA (n_heads == n_kv_heads)
|
||||||
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
|
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
|
||||||
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||||
@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSDPA(mlx_tests.MLXTestCase):
|
||||||
|
@property
|
||||||
|
def dtypes(self):
|
||||||
|
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||||
|
|
||||||
|
def test_sdpa(self):
|
||||||
|
if not mx.metal.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
shapes_64 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 128, 128, 64, 32, 32),
|
||||||
|
( 1, 64, 128, 64, 32, 32),
|
||||||
|
( 1, 65, 128, 64, 32, 8),
|
||||||
|
( 1, 64, 127, 64, 32, 8),
|
||||||
|
( 1, 65, 127, 64, 32, 8),
|
||||||
|
( 1, 127, 65, 64, 32, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
shapes_128 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 128, 128, 128, 32, 8),
|
||||||
|
( 1, 64, 128, 128, 32, 8),
|
||||||
|
( 1, 65, 127, 128, 32, 8),
|
||||||
|
( 1, 127, 65, 128, 32, 8),
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
shapes = shapes_64 + shapes_128
|
||||||
|
masks = [None, "additive", "bool", "causal"]
|
||||||
|
transposes = (False, True)
|
||||||
|
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
for t in transposes:
|
||||||
|
for mask_str in masks:
|
||||||
|
for B, qL, kL, D, qH, kH in shapes:
|
||||||
|
with self.subTest(
|
||||||
|
B=B,
|
||||||
|
qsl=qL,
|
||||||
|
ksl=kL,
|
||||||
|
head_dim=D,
|
||||||
|
n_q_heads=qH,
|
||||||
|
n_kv_heads=kH,
|
||||||
|
mask=mask_str,
|
||||||
|
transpose=t,
|
||||||
|
dtype=dtype,
|
||||||
|
):
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||||
|
B, qL, kL, D, qH, kH, mask_str, t, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
out_ref = do_attention(
|
||||||
|
mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
|
||||||
|
)
|
||||||
|
|
||||||
|
out_fst = do_attention(
|
||||||
|
mx.fast.scaled_dot_product_attention,
|
||||||
|
q_mx,
|
||||||
|
k_mx,
|
||||||
|
v_mx,
|
||||||
|
scale,
|
||||||
|
mask,
|
||||||
|
t,
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = 2e-5 if dtype == "float32" else 3e-4
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(out_ref.shape), list(out_fst.shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(failfast=True)
|
unittest.main(failfast=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user