Mirror of https://github.com/ml-explore/mlx.git
Matrix Attention kernel (#1610)

* Rough INIT
* [WIP]: Loading and Matmuls added
* [WIP]: Reductions and min working aligned kernel at headdim = 64
* [WIP] Added headdim 80 for testing
* [WIP] Update dispatch params for testing
* [WIP] Add support for unaligned seq lengths - still looks messy
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Enable gqa support
* Update benchmark and switch off 128 headdim
* Update headdim 128 tuning
* Remove older fast attention code. Write out O strided
* Disable hd=128 until further optimizations
* Enable bf16
* Fix data size bug
* Enable attn build outside of jit
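For orientation before the diff: the kernel this commit adds computes scaled dot-product attention, with grouped-query attention (GQA) handled by mapping each query head onto a key/value head via the gqa_factor that appears in the dispatch code below. The following is a minimal CPU reference sketch of that math only (the function name sdpa_reference and the per-batch [heads, length, dim] flat layout are illustrative assumptions, not part of the change); the Metal kernel computes the same result with tiled loads, blockwise matmuls and reductions, and strided output writes.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference: O[h, i, :] = softmax(scale * Q[h, i, :] . K[kv_h, :, :]^T) . V[kv_h, :, :]
// with kv_h = h / gqa_factor. Q is [H, qL, D]; K and V are [H_kv, kL, D], flattened row-major.
std::vector<float> sdpa_reference(
    const std::vector<float>& Q,
    const std::vector<float>& K,
    const std::vector<float>& V,
    int H, int H_kv, int qL, int kL, int D, float scale) {
  int gqa_factor = H / H_kv;
  std::vector<float> O(size_t(H) * qL * D, 0.0f);
  for (int h = 0; h < H; ++h) {
    int kv_h = h / gqa_factor; // GQA: several query heads share one KV head
    for (int i = 0; i < qL; ++i) {
      // Scores s_j = scale * <q_i, k_j>, tracking the max for a stable softmax.
      std::vector<float> s(kL);
      float m = -INFINITY;
      for (int j = 0; j < kL; ++j) {
        float acc = 0.0f;
        for (int d = 0; d < D; ++d) {
          acc += Q[(size_t(h) * qL + i) * D + d] * K[(size_t(kv_h) * kL + j) * D + d];
        }
        s[j] = scale * acc;
        m = std::max(m, s[j]);
      }
      float denom = 0.0f;
      for (int j = 0; j < kL; ++j) {
        s[j] = std::exp(s[j] - m);
        denom += s[j];
      }
      // O_i = sum_j softmax(s)_j * v_j
      for (int j = 0; j < kL; ++j) {
        for (int d = 0; d < D; ++d) {
          O[(size_t(h) * qL + i) * D + d] += (s[j] / denom) * V[(size_t(kv_h) * kL + j) * D + d];
        }
      }
    }
  }
  return O;
}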
@@ -6,7 +6,9 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+
+#include "mlx/backend/metal/kernels/steel/attn/params.h"
+
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/utils.h"
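The new steel/attn/params.h header pulled in above provides the AttnParams struct that carries the launch parameters. Reconstructed purely from the /* ... = */ comments in the initializer later in this diff, it plausibly looks like the sketch below; the exact types and any additional members in the real header may differ.

// Illustrative sketch only, inferred from the commented initializer in this diff.
struct AttnParams {
  int B;               // batch size
  int H;               // number of query heads
  int D;               // head dimension
  int qL;              // query sequence length
  int kL;              // key/value sequence length
  int gqa_factor;      // query heads per key/value head
  float scale;         // softmax scale applied to Q.K^T
  int NQ;              // query blocks: ceil(qL / bq)
  int NK;              // key blocks:   ceil(kL / bk)
  int NQ_aligned;      // fully populated query blocks: qL / bq
  int NK_aligned;      // fully populated key blocks:   kL / bk
  size_t Q_strides[3]; // batch / head / sequence strides of Q
  size_t K_strides[3];
  size_t V_strides[3];
  size_t O_strides[3];
};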
@@ -19,122 +21,89 @@ void sdpa_full_self_attention_metal(
     const array& q,
     const array& k,
     const array& v,
-    const float alpha,
-    array& out) {
-  std::ostringstream kname_self_attention;
-  kname_self_attention << "steel_gemm_attention_";
+    const float scale,
+    array& o) {
+  using namespace mlx::steel;

-  constexpr const int bm = 16;
-  constexpr const int bn = 16;
-  const int bk = q.shape(-1); // already forced to be 64 or 128
-  int wm = 4;
-  int wn = 1;

-  if (bk != 64 && bk != 128) {
-    throw std::runtime_error(
-        "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
-  }
+  int bd = q.shape(-1);
+  int bq = 32;
+  int bk = bd < 128 ? 32 : 16;

+  constexpr const int wm = 2;
+  constexpr const int wn = 2;
+  int B = q.shape(0);
+  int H = q.shape(1);
+  int D = q.shape(3);
+  int gqa_factor = q.shape(1) / k.shape(1);

-  std::string delimiter = "_";
+  int qL = q.shape(2);
+  int kL = k.shape(2);

-  kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
-  kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
-  kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
+  const bool align_Q = (qL % bq) == 0;
+  const bool align_K = (kL % bk) == 0;

-  for (const auto& arr : {k, v, out}) {
-    if (arr.dtype() != q.dtype()) {
-      throw std::runtime_error(
-          "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
-    }
-  }
+  metal::MTLFCList func_consts = {
+      {&align_Q, MTL::DataType::DataTypeBool, 200},
+      {&align_K, MTL::DataType::DataTypeBool, 201},
+  };

-  if (q.dtype() == float32) {
-    kname_self_attention << "itype" + delimiter + "float";
-  } else if (q.dtype() == float16) {
-    kname_self_attention << "itype" + delimiter + "half";
-  } else {
-    throw std::runtime_error(
-        "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
-  }
+  std::ostringstream kname;
+  // clang-format off
+  kname << "steel_attention_"
+        << type_to_name(q)
+        << "_bq" << bq
+        << "_bk" << bk
+        << "_bd" << bd
+        << "_wm" << wm << "_wn" << wn; // clang-format on

+  std::string base_name = kname.str();

+  // clang-format off
+  kname << "_align_Q_" << (align_Q ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

+  std::string hash_name = kname.str();

   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname_self_attention.str());
+  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
   compute_encoder.set_compute_pipeline_state(kernel);

-  uint hidden_dim = q.shape(-1);
-  uint qseq = q.shape(-2);
-  uint qheads = q.shape(-3);
+  const int NQ = (qL + bq - 1) / bq;
+  const int NK = (kL + bk - 1) / bk;

-  const uint64_t KV_sequence_length = k.shape(-2);
-  const uint query_sequence_length = q.shape(-2);
-  const uint n_q_heads = q.shape(1);
-  const uint n_kv_heads = k.shape(1);
+  const int NQ_aligned = qL / bq;
+  const int NK_aligned = kL / bk;

-  const int M = q.shape(-2);
-  const int N = M;
-  const int K = q.shape(-1);
-  const size_t batch_size_out = q.shape(0) * q.shape(1);
+  AttnParams params{
+      /* int B = */ B,
+      /* int H = */ H,
+      /* int D = */ D,

-  const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
-  const int dk = q.shape(-1);
-  const int ldq = dk;
-  const int ldk = dk;
-  const int ldv = dk;
-  const int lds = bn;
-  const int ldo = dk;
+      /* int qL = */ qL,
+      /* int kL = */ kL,

-  int tn = 1;
-  int tm = (M + bm - 1) / bm;
+      /* int gqa_factor = */ gqa_factor,
+      /* float scale = */ scale,

-  const int batch_stride_q = dk * query_sequence_length;
-  const int batch_stride_k = dk * query_sequence_length;
-  const int batch_stride_v = dk * query_sequence_length;
-  const int batch_stride_o = dk * query_sequence_length;
-  const int swizzle_log = 0;
-  const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
-  const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
-  const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
-  const int batch_ndim = int(batch_shape.size());
+      /* int NQ = */ NQ,
+      /* int NK = */ NK,

-  MLXFastAttentionParams params{
-      (int)M,
-      (int)N,
-      (int)K,
-      ldq,
-      ldk,
-      ldv,
-      lds,
-      ldo,
-      tn,
-      tm,
-      batch_stride_q,
-      batch_stride_k,
-      batch_stride_v,
-      batch_stride_o,
-      swizzle_log,
-      gemm_n_iterations_aligned,
-      gemm_k_iterations_aligned,
-      gemm_sv_m_block_iterations,
-      batch_ndim,
-      alpha};
+      /* int NQ_aligned = */ NQ_aligned,
+      /* int NK_aligned = */ NK_aligned,

-  const std::vector<size_t> batch_strides = {
-      (size_t)batch_stride_q,
-      (size_t)batch_stride_k,
-      (size_t)batch_stride_v,
-      (size_t)batch_stride_o};
+      /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
+      /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
+      /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
+      /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

   compute_encoder.set_input_array(q, 0);
   compute_encoder.set_input_array(k, 1);
   compute_encoder.set_input_array(v, 2);
-  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);
-  compute_encoder.set_vector_bytes(batch_shape, 6);
-  compute_encoder.set_vector_bytes(batch_strides, 7);

-  MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
+  MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);

   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
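To make the new dispatch logic concrete, here is a small standalone sketch that mirrors the host-side arithmetic above for one assumed set of shapes (the shape values and the "float16" stand-in for type_to_name(q) are illustrative assumptions, not from the change): tile sizes bq/bk chosen from the head dimension, alignment flags that become Metal function constants 200/201, the specialized kernel name, and a threadgroup grid of (NQ, H, B) with 32 * wm * wn threads per group.

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Hypothetical example shapes: B batches, H query heads, H_kv KV heads,
  // query length qL, key length kL, head dimension bd.
  int B = 2, H = 32, H_kv = 8, qL = 1000, kL = 1000, bd = 128;

  // Tile sizes as chosen above: 32x32 tiles below head dim 128, 32x16 at 128.
  int bq = 32;
  int bk = bd < 128 ? 32 : 16;
  constexpr int wm = 2, wn = 2;

  int gqa_factor = H / H_kv;       // query heads per KV head
  bool align_Q = (qL % bq) == 0;   // function constant 200
  bool align_K = (kL % bk) == 0;   // function constant 201

  int NQ = (qL + bq - 1) / bq;     // query blocks -> grid x dimension
  int NK = (kL + bk - 1) / bk;     // key blocks iterated inside each threadgroup

  // Kernel name specialization, mirroring the ostringstream logic above.
  std::ostringstream kname;
  kname << "steel_attention_" << "float16"
        << "_bq" << bq << "_bk" << bk << "_bd" << bd
        << "_wm" << wm << "_wn" << wn;
  std::string base_name = kname.str();
  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n');
  std::string hash_name = kname.str();

  std::cout << "base_name: " << base_name << "\n"
            << "hash_name: " << hash_name << "\n"
            << "gqa_factor: " << gqa_factor << ", NQ: " << NQ << ", NK: " << NK << "\n"
            << "grid (threadgroups): (" << NQ << ", " << H << ", " << B << ")\n"
            << "threads per threadgroup: 32 x " << wm << " x " << wn << "\n";
  return 0;
}

Note the split between base_name (used to fetch the compiled kernel source) and hash_name (which also encodes the align_Q/align_K function-constant values), so each alignment combination gets its own specialized pipeline state.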
@@ -356,7 +325,24 @@ void ScaledDotProductAttention::eval_gpu(
     const auto& q = copy_unless(is_matrix_contiguous, q_pre);
     const auto& k = copy_unless(is_matrix_contiguous, k_pre);
     const auto& v = copy_unless(is_matrix_contiguous, v_pre);
-    o.set_data(allocator::malloc_or_wait(o.nbytes()));
+
+    size_t str_oD = 1;
+    size_t str_oH = o.shape(3);
+    size_t str_oL = o.shape(1) * str_oH;
+    size_t str_oB = o.shape(2) * str_oL;
+    size_t data_size = o.shape(0) * str_oB;
+
+    array::Flags flags{
+        /* bool contiguous = */ 1,
+        /* bool row_contiguous = */ 0,
+        /* bool col_contiguous = */ 0,
+    };
+
+    o.set_data(
+        allocator::malloc_or_wait(o.nbytes()),
+        data_size,
+        {str_oB, str_oH, str_oL, str_oD},
+        flags);

     sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
   }
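The "Write out O strided" bullet in the commit message appears to refer to this allocation: o keeps its logical shape (B, H, L, D), but the strides handed to set_data put the head axis innermost after D, so the kernel writes its output in (B, L, H, D) memory order (dense, hence contiguous = 1 but row_contiguous = 0), which is the layout the downstream sequence-major view of the output wants. A small sketch of that stride arithmetic, with an assumed example shape:

#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical output shape (B, H, L, D), matching o.shape(0..3) above.
  size_t B = 2, H = 8, L = 65, D = 128;

  // Strides as computed in the diff.
  size_t str_oD = 1;
  size_t str_oH = D;          // o.shape(3)
  size_t str_oL = H * str_oH; // o.shape(1) * str_oH
  size_t str_oB = L * str_oL; // o.shape(2) * str_oL
  size_t data_size = B * str_oB;

  // Offset of logical element (b, h, l, d) under strides {str_oB, str_oH, str_oL, str_oD}.
  auto offset = [&](size_t b, size_t h, size_t l, size_t d) {
    return b * str_oB + h * str_oH + l * str_oL + d * str_oD;
  };

  // Walking d, then h, then l, then b yields consecutive offsets 0, 1, 2, ...
  // i.e. the buffer is laid out as (B, L, H, D) even though the logical
  // shape of o stays (B, H, L, D).
  std::cout << "data_size = " << data_size << "\n";
  std::cout << "offset(0,0,0,1) = " << offset(0, 0, 0, 1) << "\n"; // +1   (next d)
  std::cout << "offset(0,1,0,0) = " << offset(0, 1, 0, 0) << "\n"; // +D   (next head)
  std::cout << "offset(0,0,1,0) = " << offset(0, 0, 1, 0) << "\n"; // +H*D (next token)
  return 0;
}

Passing data_size (the true element count under these strides) to set_data also lines up with the "Fix data size bug" bullet in the commit message.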