Add boolean mask support in vector SDPA (#1757)

This commit is contained in:
Awni Hannun
2025-01-07 20:24:53 -08:00
committed by GitHub
parent 516ded618b
commit d1766f2c70
5 changed files with 226 additions and 74 deletions

View File

@@ -1,6 +1,5 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -116,7 +115,8 @@ void sdpa_vector(
const array& k,
const array& v,
array& out,
float scale) {
float scale,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -134,9 +134,16 @@ void sdpa_vector(
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -149,6 +156,14 @@ void sdpa_vector(
compute_encoder.set_bytes(k_stride, 6);
compute_encoder.set_bytes(v_stride, 7);
compute_encoder.set_bytes(scale, 8);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 9);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 10);
compute_encoder.set_bytes(head_stride, 11);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
const array& k,
const array& v,
array& out,
float scale) {
float scale,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(k_stride, 8);
compute_encoder.set_bytes(v_stride, 9);
compute_encoder.set_bytes(scale, 10);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 12);
compute_encoder.set_bytes(head_stride, 13);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() == 3);
auto& s = stream();
auto& d = metal::device(s.device);
@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
const auto& q = copy_unless(is_contiguous, q_pre);
// 1, heads, seq_len, head_dim
// mask [1, query_heads, 1, seq_len]
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
} else {
sdpa_vector(s, d, q, k, v, o, scale_);
sdpa_vector(s, d, q, k, v, o, scale_, mask);
}
}