Support fused masking in Attention (#1924)

* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
This commit is contained in:
Jagrit Digani
2025-03-20 11:01:32 -07:00
committed by GitHub
parent 3c164fca8c
commit 9adcd1a650
11 changed files with 504 additions and 148 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-25 Apple Inc.
using namespace mlx::steel;
@@ -9,6 +9,9 @@ using namespace mlx::steel;
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
template <typename T>
struct TransformScale {
T scale;
@@ -69,6 +72,7 @@ template <
int BD,
int WM,
int WN,
typename MaskType = float,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
@@ -76,6 +80,8 @@ template <
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -102,6 +108,11 @@ template <
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch
tidl.y * mask_params->M_strides[1]; // Head
}
// Prepare threadgroup memory
constexpr short padQ = 16 / sizeof(T);
constexpr short padK = 16 / sizeof(T);
@@ -203,7 +214,7 @@ template <
// Load Q blocks apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
loader_q.load_safe(short2(BD, params->qL_rem));
} else {
loader_q.load_unsafe();
}
@@ -221,12 +232,19 @@ template <
max_score[i] = Limits<AccumType>::min;
}
int kb_lim = params->NK;
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
for (int kb = 0; kb < kb_lim; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
loader_k.load_safe(short2(BD, params->kL_rem));
} else {
loader_k.load_unsafe();
}
@@ -250,12 +268,11 @@ template <
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out of length sequence
// Mask out length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -264,7 +281,7 @@ template <
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
if ((col_pos + jj) >= params->kL_rem) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
@@ -272,11 +289,78 @@ template <
}
}
// Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos =
tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if (row_pos < (col_pos + jj)) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
// Other masking as needed
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
using frag_t = typename MMAFrag_mask_t::frag_type;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
frag_t mfrag;
MMAFrag_mask_t::load_safe(
mfrag,
mask,
int(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,
row_pos,
col_pos);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
if constexpr (is_bool) {
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
}
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
loader_v.load_safe(short2(BD, params->kL_rem));
} else {
loader_v.load_unsafe();
}
@@ -367,8 +451,7 @@ template <
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

View File

@@ -1,4 +1,4 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2024-25 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
@@ -6,26 +6,23 @@
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
instantiate_kernel( \
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
"_wm" #wm "_wn" #wn "_mask" #mname, \
attention, dtype, bq, bk, bd, wm, wn, mtype, float)
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
#define instantiate_attn_mask_helper(iname, itype) \
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
instantiate_attn_shapes_helper(iname, itype, bool_, bool)
instantiate_attn_shapes_helper(float32, float);
instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat16_t);
instantiate_attn_mask_helper(float32, float);
// clang-format on

View File

@@ -111,7 +111,7 @@ struct BaseMMAFrag<T, 8, 8> {
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}

View File

@@ -26,11 +26,19 @@ struct AttnParams {
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
int qL_rem; ///< Remainder in last query block
int kL_rem; ///< Remainder in last key/value block
int qL_off; ///< Offset in query sequence start
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
struct AttnMaskParams {
int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
};
} // namespace steel
} // namespace mlx

View File

@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
const array& k,
const array& v,
const float scale,
array& o) {
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
using namespace mlx::steel;
int wm = 4;
@@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal(
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool do_causal = do_causal_;
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
};
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};
std::ostringstream kname;
// clang-format off
@@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal(
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm << "_wn" << wn; // clang-format on
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal(
/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,
/* int qL_rem = */ (qL - NQ_aligned * bq),
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
if (mask) {
auto m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};
compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu(
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(3) == 1;
return arr.strides(-1) == 1;
};
// We are in vector mode ie single query
@@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
}
d.add_temporaries(std::move(copies), s.index);

View File

@@ -567,7 +567,7 @@ array scaled_dot_product_attention(
const array& keys,
const array& values,
const float scale,
const std::optional<array>& mask,
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
const std::optional<int> memory_efficient_threshold,
StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) {
@@ -578,10 +578,29 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
}
if (mask && (*mask).ndim() > 4) {
bool do_causal = false;
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
bool has_bool_mask = false;
if (has_str_mask) {
if (std::get<std::string>(mask) != "causal") {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] invalid mask option '"
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
throw std::invalid_argument(msg.str());
} else {
do_causal = true;
}
}
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] the mask with shape "
<< (*mask).shape() << " expected to have at most rank 4";
<< (std::get<array>(mask)).shape()
<< " expected to have at most rank 4";
throw std::invalid_argument(msg.str());
}
@@ -631,9 +650,11 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
if (mask) {
if (has_arr_mask) {
// Check type
if (promote_types(mask->dtype(), final_type) != final_type) {
auto mask_arr = std::get<array>(mask);
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
@@ -642,9 +663,10 @@ array scaled_dot_product_attention(
// Check shape
auto mask_shape = queries.shape();
mask_shape.back() = keys.shape(-2);
if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
msg << "[scaled_dot_product_attention] Mask with shape "
<< mask_arr.shape()
<< " does not broadcast to implicit scores with shape " << mask_shape
<< ".";
throw std::invalid_argument(msg.str());
@@ -662,7 +684,7 @@ array scaled_dot_product_attention(
threshold = std::max(1, memory_efficient_threshold.value());
}
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
@@ -676,9 +698,21 @@ array scaled_dot_product_attention(
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3) {
if (inputs.size() > 3 || do_causal) {
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs[3];
auto mask = inputs.back();
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
mask = greater_equal(q_idx, k_idx, s);
}
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
mask = expand_dims(mask, -3, s);
@@ -702,9 +736,10 @@ array scaled_dot_product_attention(
};
auto stream = to_stream(s);
const size_t value_head_dim = v.shape(-1);
const size_t query_head_dim = q.shape(-1);
const size_t query_sequence_length = q.shape(2);
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
@@ -712,27 +747,33 @@ array scaled_dot_product_attention(
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu;
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full = query_sequence_length >= threshold &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
stream.device == Device::gpu;
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= k.shape(-2)) &&
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;
const bool implementation_supports_use_case =
supports_sdpa_full || supports_sdpa_vector;
std::vector<array> inputs = {q, k, v};
if (mask) {
inputs.push_back(*mask);
if (has_arr_mask) {
inputs.push_back(std::get<array>(mask));
}
if (implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal),
std::move(inputs));
}
return fallback(inputs)[0];
@@ -741,7 +782,7 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return scale_ == a_other.scale_;
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
}
array pack_and_quantize(

View File

@@ -3,6 +3,7 @@
#pragma once
#include <optional>
#include <variant>
#include "mlx/utils.h"
@@ -47,7 +48,7 @@ array scaled_dot_product_attention(
const array& keys,
const array& values,
const float scale,
const std::optional<array>& mask = std::nullopt,
const std::variant<std::monostate, std::string, array>& mask = {},
const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});

View File

@@ -206,8 +206,9 @@ class ScaledDotProductAttention : public Custom {
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale)
: Custom(stream, fallback), scale_(scale) {}
const float scale,
const bool do_causal)
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -225,12 +226,13 @@ class ScaledDotProductAttention : public Custom {
DEFINE_PRINT(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_pair(nullptr, scale_);
return std::make_tuple(nullptr, scale_, do_causal_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool do_causal_;
};
class AffineQuantize : public Custom {