Add boolean mask support in vector SDPA (#1757)

Author: Awni Hannun
Date: 2025-01-07 20:24:53 -08:00
Committed by: GitHub
Parent: 516ded618b
Commit: d1766f2c70
5 changed files with 226 additions and 74 deletions
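
For orientation, the semantics this commit adds: in the vector (single-query) SDPA path a boolean mask selects which key positions a query may attend to, and positions whose mask entry is false are skipped entirely, which is equivalent to giving them a score of -inf before the softmax. Below is a minimal scalar reference of that computation in C++ (a hypothetical helper, not MLX code; the Metal kernels in this diff produce the same result with a streaming softmax instead of two passes):

```cpp
#include <cmath>
#include <vector>

// Scalar reference for masked single-query attention:
// out = softmax(scale * q.K^T, masked) . V
// q: [D], keys/values: [N][D], mask: [N] (true = attend).
// Assumes at least one unmasked key. Illustrative only.
std::vector<float> masked_sdpa_ref(
    const std::vector<float>& q,
    const std::vector<std::vector<float>>& keys,
    const std::vector<std::vector<float>>& values,
    const std::vector<bool>& mask,
    float scale) {
  size_t N = keys.size(), D = q.size();
  std::vector<float> scores(N, -INFINITY);
  float max_score = -INFINITY;
  for (size_t i = 0; i < N; ++i) {
    if (!mask[i]) continue; // masked-out key: score stays -inf
    float s = 0;
    for (size_t d = 0; d < D; ++d) s += q[d] * keys[i][d];
    scores[i] = s * scale;
    max_score = std::fmax(max_score, scores[i]);
  }
  float sum_exp = 0;
  std::vector<float> out(D, 0.0f);
  for (size_t i = 0; i < N; ++i) {
    if (!mask[i]) continue;
    float w = std::exp(scores[i] - max_score); // stable softmax weight
    sum_exp += w;
    for (size_t d = 0; d < D; ++d) out[d] += w * values[i][d];
  }
  for (size_t d = 0; d < D; ++d) out[d] /= sum_exp;
  return out;
}
```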

File: mlx/backend/metal/kernels/sdpa_vector.metal

@@ -4,6 +4,8 @@
using namespace metal;
+constant bool has_mask [[function_constant(20)]];
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
@@ -15,6 +17,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -39,6 +44,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+  }
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
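
The offsets above pin each simdgroup to its own key position: simdgroup g starts at mask row g and, as the loop in the next hunk shows, advances by BN rows per iteration. A small sketch of the index a simdgroup reads on iteration t (hypothetical helper mirroring the kernel's pointer arithmetic):

```cpp
// Which mask element simdgroup g of the threadgroup tests on loop
// iteration t: the BN simdgroups stride through the N key positions
// together, one key per simdgroup per iteration.
inline int mask_index(
    int head_idx, int g, int t, int BN,
    int mask_head_stride, int mask_seq_stride) {
  int key_pos = g + t * BN; // key position handled this iteration
  return head_idx * mask_head_stride + key_pos * mask_seq_stride;
}
```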
@@ -54,34 +62,39 @@ template <typename T, int D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }
-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
    }

    // Move the pointers to the next kv
    keys += stride;
    values += stride;
+    if (has_mask) {
+      mask += BN * mask_seq_stride;
+    }
}
// Each thread has a partial part of the output so we need to combine them.
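
The accumulator update above is the standard streaming-softmax rescaling: when a new score arrives, the running max may grow, so the previously accumulated exponential sum and output are rescaled by exp(old_max - new_max) before the new term is folded in. A scalar sketch of the same step (hypothetical names, mirroring the kernel's new_max/factor/exp_score lines):

```cpp
#include <cmath>

// One streaming-softmax step over a 1-D value, for illustration.
struct SoftmaxState {
  float m = -INFINITY; // running max (max_score)
  float s = 0.0f;      // running sum of exponentials (sum_exp_score)
  float o = 0.0f;      // running unnormalized output accumulator
};

void update(SoftmaxState& st, float score, float value) {
  float new_max = std::fmax(st.m, score);
  float factor = std::exp(st.m - new_max);     // rescales the old terms
  float exp_score = std::exp(score - new_max); // weight of the new term
  st.m = new_max;
  st.s = st.s * factor + exp_score;
  st.o = st.o * factor + exp_score * value;
  // At any point, o / s equals the softmax-weighted sum of the
  // values whose scores have been seen so far.
}
```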
@@ -126,6 +139,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -155,6 +171,10 @@ template <typename T, int D>
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_seq_stride;
+  }
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
@@ -171,34 +191,39 @@ template <typename T, int D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int i = 0; i < elem_per_thread; i++) {
+        k[i] = keys[i];
+      }
-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int i = 0; i < elem_per_thread; i++) {
+        score += q[i] * k[i];
+      }
+      score = simd_sum(score);
-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
+      // Update the output accumulator
+      for (int i = 0; i < elem_per_thread; i++) {
+        o[i] = o[i] * factor + exp_score * values[i];
+      }
    }

    // Move the pointers to the next kv
    keys += blocks * stride;
    values += blocks * stride;
+    if (has_mask) {
+      mask += BN * blocks * mask_seq_stride;
+    }
}
// Each thread has a partial part of the output so we need to combine them.
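
In the 2-pass variant, each block writes a partial output plus its local max and exp-sum (the sums/maxs buffers above), and a second pass merges them. A sketch of that merge, assuming for illustration that each block's partial output is unnormalized (if pass 1 instead stores locally normalized outputs, each o[b] would first be scaled back by its s[b]):

```cpp
#include <cmath>
#include <vector>

// Merge per-block streaming-softmax partials (o_b, m_b, s_b) into the
// final output; 1-D output for illustration.
float merge_blocks(
    const std::vector<float>& o, // per-block partial outputs
    const std::vector<float>& m, // per-block maxima (maxs buffer)
    const std::vector<float>& s) // per-block exp-sums (sums buffer)
{
  float M = -INFINITY;
  for (float mb : m) M = std::fmax(M, mb);
  float S = 0.0f, out = 0.0f;
  for (size_t b = 0; b < o.size(); ++b) {
    float factor = std::exp(m[b] - M); // rescale block b to the global max
    S += s[b] * factor;
    out += o[b] * factor;
  }
  return out / S; // normalize once, at the very end
}
```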

File: mlx/backend/metal/scaled_dot_product_attention.cpp

@@ -1,6 +1,5 @@
// Copyright © 2024 Apple Inc.
-#include <cassert>
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -116,7 +115,8 @@ void sdpa_vector(
const array& k,
const array& v,
array& out,
-    float scale) {
+    float scale,
+    const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -134,9 +134,16 @@ void sdpa_vector(
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
+  bool has_mask = mask.has_value();
+  metal::MTLFCList func_consts = {
+      {&has_mask, MTL::DataType::DataTypeBool, 20},
+  };
+  std::string hash_name = kname;
+  hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname);
+  auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -149,6 +156,14 @@ void sdpa_vector(
compute_encoder.set_bytes(k_stride, 6);
compute_encoder.set_bytes(v_stride, 7);
compute_encoder.set_bytes(scale, 8);
+  if (has_mask) {
+    auto& m = *mask;
+    compute_encoder.set_input_array(m, 9);
+    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
+    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
+    compute_encoder.set_bytes(seq_stride, 10);
+    compute_encoder.set_bytes(head_stride, 11);
+  }
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
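
The stride extraction above is what lets one kernel handle broadcast masks: MLX gives size-1 (broadcast) dimensions a stride of 0, so a mask shared across heads arrives with head_stride == 0 and every head reads the same row. A small sketch of the same logic with hypothetical shapes:

```cpp
#include <cstdint>
#include <vector>

// Mirror of the host-side extraction: the last stride walks the key
// sequence, the stride three-from-the-end walks query heads. Broadcast
// dims (size 1) carry stride 0, so a mask of shape [1, 1, 1, seq_len]
// yields head_stride = 0 and all heads share one row of booleans.
struct MaskStrides {
  int32_t seq_stride;
  int32_t head_stride;
};

MaskStrides mask_strides(const std::vector<int64_t>& strides) {
  size_t ndim = strides.size();
  return {
      ndim >= 1 ? static_cast<int32_t>(strides.back()) : 0,
      ndim >= 3 ? static_cast<int32_t>(strides[ndim - 3]) : 0,
  };
}
// Example: a contiguous mask of shape [1, n_heads, 1, seq_len] has
// strides [n_heads*seq_len, seq_len, seq_len, 1], giving seq_stride = 1
// and head_stride = seq_len: each head gets its own row.
```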
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
const array& k,
const array& v,
array& out,
-    float scale) {
+    float scale,
+    const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
+  bool has_mask = mask.has_value();
+  metal::MTLFCList func_consts = {
+      {&has_mask, MTL::DataType::DataTypeBool, 20},
+  };
+  std::string hash_name = kname;
+  hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname);
+  auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(k_stride, 8);
compute_encoder.set_bytes(v_stride, 9);
compute_encoder.set_bytes(scale, 10);
+  if (has_mask) {
+    auto& m = *mask;
+    compute_encoder.set_input_array(m, 11);
+    int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
+    int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
+    compute_encoder.set_bytes(seq_stride, 12);
+    compute_encoder.set_bytes(head_stride, 13);
+  }
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
-  assert(inputs.size() == 3);
auto& s = stream();
auto& d = metal::device(s.device);
@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
const auto& q = copy_unless(is_contiguous, q_pre);
+    // 1, heads, seq_len, head_dim
+    // mask [1, query_heads, 1, seq_len]
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
+    auto mask =
+        inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
} else {
-      sdpa_vector(s, d, q, k, v, o, scale_);
+      sdpa_vector(s, d, q, k, v, o, scale_, mask);
}
}

File: mlx/fast.cpp

@@ -609,27 +609,32 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
-  if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
-    std::ostringstream msg;
-    msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-        << final_type << ".";
-    throw std::invalid_argument(msg.str());
+  if (mask) {
+    // Check type
+    if (promote_types(mask->dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Check shape
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+          << " does not broadcast to implicit scores with shape " << mask_shape
+          << ".";
+      throw std::invalid_argument(msg.str());
+    }
}
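
The shape check builds the implicit score shape by taking the query shape and swapping its last axis for the key sequence length, then requires the mask to broadcast to it. With hypothetical sizes: queries [1, 32, 1, 128] and keys [1, 8, 4096, 128] give scores [1, 32, 1, 4096], so masks shaped [4096], [1, 1, 1, 4096], or [1, 32, 1, 4096] all pass, while [1, 32, 1, 2048] is rejected. A sketch of that rule as a standalone helper (not MLX API):

```cpp
#include <vector>

// True when the mask broadcasts to the implicit score shape:
// the query's shape with its last axis replaced by the key seq length.
bool mask_broadcasts(
    const std::vector<int>& mask_shape, // e.g. {1, 1, 1, 4096}
    std::vector<int> q_shape,           // e.g. {1, 32, 1, 128}
    int k_seq_len) {                    // e.g. 4096
  std::vector<int>& scores = q_shape;
  scores.back() = k_seq_len; // {1, 32, 1, 4096}
  if (mask_shape.size() > scores.size()) return false;
  // Right-align and apply broadcasting: each mask dim must be 1 or match.
  for (size_t i = 0; i < mask_shape.size(); ++i) {
    int md = mask_shape[mask_shape.size() - 1 - i];
    int sd = scores[scores.size() - 1 - i];
    if (md != 1 && md != sd) return false;
  }
  return true;
}
```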
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
-  /* generic implementation for use cases that Metal implementation does not
-   * support. For non-supported cases listed below, use MLX primitives:
-   *   * CPU implementation
-   *   * batch size > 1 for decoding or causal attention
-   *   * query sequence length > 1 for decoding
-   *   * query sequence length > 16 && non-null mask (causal attention)
-   *   * non-null mask
-   *   * dtype is not fp32 or fp16
-   */
+  /* Generic implementation for use cases that Metal implementation does not
+   * support. */
int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
@@ -690,27 +695,27 @@ array scaled_dot_product_attention(
!mask.has_value() && sdpa_full_supported_head_dim &&
stream.device == Device::gpu;
+  const bool supported_mask = !mask || (mask->dtype() == bool_);
const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      !mask.has_value() && sdpa_vector_supported_head_dim &&
+      supported_mask && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;
implementation_supports_use_case &=
supports_sdpa_full || supports_sdpa_vector;
+  std::vector<array> inputs = {q, k, v};
+  if (mask) {
+    inputs.push_back(*mask);
+  }
if (implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
-        {q, k, v});
-  }
-  if (mask.has_value()) {
-    return fallback({q, k, v, mask.value()})[0];
-  } else {
-    return fallback({q, k, v})[0];
+        std::move(inputs));
  }
+  return fallback(inputs)[0];
}
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {