diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 291f87203..74843a2f4 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -12,7 +12,7 @@ dtype = mx.float16 loops = 10 -def attention(q, k, v): +def attention(q, k, v, mask=None): def _sdpa(q, k, v): B, Hq, L, D = q.shape _, Hk, S, _ = k.shape @@ -20,6 +20,9 @@ def attention(q, k, v): k = k[:, :, None, :, :] v = v[:, :, None, :, :] s = q @ k.transpose(0, 1, 2, 4, 3) + if mask is not None: + m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S) + s = mx.where(m, s, mx.finfo(s.dtype).min) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) o = p @ v return o.reshape(B, Hq, L, D) @@ -29,9 +32,9 @@ def attention(q, k, v): return q -def sdpa(q, k, v): +def sdpa(q, k, v, mask=None): for i in range(loops): - q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) return q @@ -53,6 +56,26 @@ def time_self_attention_sdpa(): time_fn(sdpa, q, k, v) +def time_self_attention_sdpa_with_mask(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) + k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + mask = mx.full((L,), True) + mask[L // 2 :] = False + mx.eval(q, k, v, mask) + + def sdpa_mask(*args): + return sdpa(*args, mask=mask) + + def attention_mask(*args): + return attention(*args, mask=mask) + + time_fn(attention_mask, q, k, v) + time_fn(sdpa_mask, q, k, v) + + if __name__ == "__main__": time_self_attention_sdpa() time_self_attention_primitives() + time_self_attention_sdpa_with_mask() diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 8b6af638e..8a7351b7c 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -4,6 +4,8 @@ using namespace metal; +constant bool has_mask [[function_constant(20)]]; + template [[kernel]] void sdpa_vector( const device T* queries [[buffer(0)]], @@ -15,6 +17,9 @@ template const constant size_t& k_stride, const constant size_t& v_stride, const constant float& scale, + const device bool* mask [[function_constant(has_mask)]], + const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_head_stride [[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -39,6 +44,9 @@ template queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; + if (has_mask) { + mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; + } out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -54,34 +62,39 @@ template // For each key for (int i = simd_gid; i < N; i += BN) { - // Read the key - for (int i = 0; i < elem_per_thread; i++) { - k[i] = keys[i]; - } + if (!has_mask || mask[0]) { + // Read the key + for (int j = 0; j < elem_per_thread; j++) { + k[j] = keys[j]; + } - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = simd_sum(score); + // Compute the i-th score + U score = 0; + for (int j = 0; j < elem_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * values[i]; + // Update the output accumulator + for (int j = 0; j < elem_per_thread; j++) { + o[j] = o[j] * factor + exp_score * values[j]; + } } // Move the pointers to the next kv keys += stride; values += stride; + if (has_mask) { + mask += BN * mask_seq_stride; + } } // Each thread has a partial part of the output so we need to combine them. @@ -126,6 +139,9 @@ template const constant size_t& k_stride, const constant size_t& v_stride, const constant float& scale, + const device bool* mask [[function_constant(has_mask)]], + const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_head_stride [[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -155,6 +171,10 @@ template values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + simd_lid * elem_per_thread; out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; + if (has_mask) { + mask += head_idx * mask_head_stride + + (block_idx * BN + simd_gid) * mask_seq_stride; + } sums += head_idx * blocks + block_idx; maxs += head_idx * blocks + block_idx; @@ -171,34 +191,39 @@ template // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { - // Read the key - for (int i = 0; i < elem_per_thread; i++) { - k[i] = keys[i]; - } + if (!has_mask || mask[0]) { + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = simd_sum(score); + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * values[i]; + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } } // Move the pointers to the next kv keys += blocks * stride; values += blocks * stride; + if (has_mask) { + mask += BN * blocks * mask_seq_stride; + } } // Each thread has a partial part of the output so we need to combine them. diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index db5abbf90..105a1e87a 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,6 +1,5 @@ // Copyright © 2024 Apple Inc. -#include #include #include "mlx/backend/common/compiled.h" @@ -116,7 +115,8 @@ void sdpa_vector( const array& k, const array& v, array& out, - float scale) { + float scale, + const std::optional& mask) { // Set the kernel name std::string kname; kname.reserve(64); @@ -134,9 +134,16 @@ void sdpa_vector( MTL::Size group_dims(1024, 1, 1); MTL::Size grid_dims(1, B, 1); + bool has_mask = mask.has_value(); + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + }; + std::string hash_name = kname; + hash_name += has_mask ? "_mask" : "_nomask"; + // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname); + auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -149,6 +156,14 @@ void sdpa_vector( compute_encoder.set_bytes(k_stride, 6); compute_encoder.set_bytes(v_stride, 7); compute_encoder.set_bytes(scale, 8); + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array(m, 9); + int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; + int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; + compute_encoder.set_bytes(seq_stride, 10); + compute_encoder.set_bytes(head_stride, 11); + } // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -161,7 +176,8 @@ void sdpa_vector_2pass( const array& k, const array& v, array& out, - float scale) { + float scale, + const std::optional& mask) { // Set the kernel name std::string kname; kname.reserve(64); @@ -198,9 +214,17 @@ void sdpa_vector_2pass( d.add_temporary(sums, s.index); d.add_temporary(maxs, s.index); + bool has_mask = mask.has_value(); + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + }; + std::string hash_name = kname; + hash_name += has_mask ? "_mask" : "_nomask"; + // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname); + auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -215,6 +239,14 @@ void sdpa_vector_2pass( compute_encoder.set_bytes(k_stride, 8); compute_encoder.set_bytes(v_stride, 9); compute_encoder.set_bytes(scale, 10); + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array(m, 11); + int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; + int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; + compute_encoder.set_bytes(seq_stride, 12); + compute_encoder.set_bytes(head_stride, 13); + } // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -247,8 +279,6 @@ void sdpa_vector_2pass( void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { - assert(inputs.size() == 3); - auto& s = stream(); auto& d = metal::device(s.device); @@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) == 1) { const auto& q = copy_unless(is_contiguous, q_pre); + // 1, heads, seq_len, head_dim + // mask [1, query_heads, 1, seq_len] const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); @@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu( o.set_data(allocator::malloc_or_wait(o.nbytes())); } + auto mask = + inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt; + // We route to the 2 pass fused attention if // - The device is large and the sequence length long // - The sequence length is even longer and we have gqa char devc = d.get_architecture().back(); if ((devc == 'd' && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_); + sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask); } else { - sdpa_vector(s, d, q, k, v, o, scale_); + sdpa_vector(s, d, q, k, v, o, scale_, mask); } } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 7b3570b90..79f73aee6 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -609,27 +609,32 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (mask && promote_types((*mask).dtype(), final_type) != final_type) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask type must promote to output type. " - << final_type << "."; - throw std::invalid_argument(msg.str()); + if (mask) { + // Check type + if (promote_types(mask->dtype(), final_type) != final_type) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Mask type must promote to output type. " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + // Check shape + auto mask_shape = queries.shape(); + mask_shape.back() = keys.shape(-2); + if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape() + << " does not broadcast to implicit scores with shape " << mask_shape + << "."; + throw std::invalid_argument(msg.str()); + } } auto q = astype(queries, final_type, s); auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); - /* generic implementation for use cases that Metal implementation does not - * support. For non-supported cases listed below, use MLX primitives: - * * CPU implementation - * * batch size > 1 for decoding or causal attention - * * query sequence length > 1 for decoding - * * query sequence length > 16 && non-null mask (causal attention) - * * non-null mask - * * dtype is not fp32 or fp16 - */ - + /* Generic implementation for use cases that Metal implementation does not + * support. */ int threshold = 32; // TODO: Fix after dev if (memory_efficient_threshold.has_value()) { threshold = std::max(1, memory_efficient_threshold.value()); @@ -690,27 +695,27 @@ array scaled_dot_product_attention( !mask.has_value() && sdpa_full_supported_head_dim && stream.device == Device::gpu; + const bool supported_mask = !mask || (mask->dtype() == bool_); const bool supports_sdpa_vector = query_sequence_length == 1 && - !mask.has_value() && sdpa_vector_supported_head_dim && + supported_mask && sdpa_vector_supported_head_dim && stream.device == Device::gpu; implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; + std::vector inputs = {q, k, v}; + if (mask) { + inputs.push_back(*mask); + } if (implementation_supports_use_case) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), final_type, std::make_shared(stream, fallback, scale), - {q, k, v}); - } - - if (mask.has_value()) { - return fallback({q, k, v, mask.value()})[0]; - } else { - return fallback({q, k, v})[0]; + std::move(inputs)); } + return fallback(inputs)[0]; } bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 3b86ef17d..f1298cb35 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -10,7 +10,10 @@ import numpy as np def mlx_primitives_sdpa(q, k, v, scale, mask=None): p = (q * scale) @ k.transpose(0, 1, 3, 2) if mask is not None: - p += mask + if mask.dtype == mx.bool_: + p = mx.where(mask, p, mx.finfo(mx.float32).min) + else: + p += mask scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype) return scores @ v @@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + def test_fast_sdpa_vector(self): + D = 64 + L = 43 + Nq = 4 + Nkv = 1 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + with self.assertRaises(ValueError): + mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=mx.full((Nq, 2, L), False), + ) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + L = 4096 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + if __name__ == "__main__": unittest.main(failfast=True)