mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add boolean mask support in vector SDPA (#1757)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -116,7 +115,8 @@ void sdpa_vector(
|
||||
const array& k,
|
||||
const array& v,
|
||||
array& out,
|
||||
float scale) {
|
||||
float scale,
|
||||
const std::optional<array>& mask) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
@@ -134,9 +134,16 @@ void sdpa_vector(
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
MTL::Size grid_dims(1, B, 1);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@@ -149,6 +156,14 @@ void sdpa_vector(
|
||||
compute_encoder.set_bytes(k_stride, 6);
|
||||
compute_encoder.set_bytes(v_stride, 7);
|
||||
compute_encoder.set_bytes(scale, 8);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 9);
|
||||
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
|
||||
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
|
||||
compute_encoder.set_bytes(seq_stride, 10);
|
||||
compute_encoder.set_bytes(head_stride, 11);
|
||||
}
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
|
||||
const array& k,
|
||||
const array& v,
|
||||
array& out,
|
||||
float scale) {
|
||||
float scale,
|
||||
const std::optional<array>& mask) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
|
||||
d.add_temporary(sums, s.index);
|
||||
d.add_temporary(maxs, s.index);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
|
||||
compute_encoder.set_bytes(k_stride, 8);
|
||||
compute_encoder.set_bytes(v_stride, 9);
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11);
|
||||
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
|
||||
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
|
||||
compute_encoder.set_bytes(seq_stride, 12);
|
||||
compute_encoder.set_bytes(head_stride, 13);
|
||||
}
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) == 1) {
|
||||
const auto& q = copy_unless(is_contiguous, q_pre);
|
||||
// 1, heads, seq_len, head_dim
|
||||
// mask [1, query_heads, 1, seq_len]
|
||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
|
||||
@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
}
|
||||
|
||||
auto mask =
|
||||
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
|
||||
|
||||
// We route to the 2 pass fused attention if
|
||||
// - The device is large and the sequence length long
|
||||
// - The sequence length is even longer and we have gqa
|
||||
char devc = d.get_architecture().back();
|
||||
if ((devc == 'd' && k.shape(2) >= 1024) ||
|
||||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
|
||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
|
||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
|
||||
} else {
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
sdpa_vector(s, d, q, k, v, o, scale_, mask);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user