Matrix Attention kernel (#1610)

* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
This commit is contained in:
Jagrit Digani
2024-11-22 10:34:05 -08:00
committed by GitHub
parent c79f6a4a8c
commit 02bec0bb6d
14 changed files with 2049 additions and 1109 deletions

View File

@@ -6,7 +6,9 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
@@ -19,122 +21,89 @@ void sdpa_full_self_attention_metal(
const array& q,
const array& k,
const array& v,
const float alpha,
array& out) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
const float scale,
array& o) {
using namespace mlx::steel;
constexpr const int bm = 16;
constexpr const int bn = 16;
const int bk = q.shape(-1); // already forced to be 64 or 128
int wm = 4;
int wn = 1;
if (bk != 64 && bk != 128) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
}
int bd = q.shape(-1);
int bq = 32;
int bk = bd < 128 ? 32 : 16;
constexpr const int wm = 2;
constexpr const int wn = 2;
int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
std::string delimiter = "_";
int qL = q.shape(2);
int kL = k.shape(2);
kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
};
if (q.dtype() == float32) {
kname_self_attention << "itype" + delimiter + "float";
} else if (q.dtype() == float16) {
kname_self_attention << "itype" + delimiter + "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm << "_wn" << wn; // clang-format on
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str());
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2);
uint qheads = q.shape(-3);
const int NQ = (qL + bq - 1) / bq;
const int NK = (kL + bk - 1) / bk;
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
const int NQ_aligned = qL / bq;
const int NK_aligned = kL / bk;
const int M = q.shape(-2);
const int N = M;
const int K = q.shape(-1);
const size_t batch_size_out = q.shape(0) * q.shape(1);
AttnParams params{
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,
const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
const int dk = q.shape(-1);
const int ldq = dk;
const int ldk = dk;
const int ldv = dk;
const int lds = bn;
const int ldo = dk;
/* int qL = */ qL,
/* int kL = */ kL,
int tn = 1;
int tm = (M + bm - 1) / bm;
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
const int batch_stride_q = dk * query_sequence_length;
const int batch_stride_k = dk * query_sequence_length;
const int batch_stride_v = dk * query_sequence_length;
const int batch_stride_o = dk * query_sequence_length;
const int swizzle_log = 0;
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
const int batch_ndim = int(batch_shape.size());
/* int NQ = */ NQ,
/* int NK = */ NK,
MLXFastAttentionParams params{
(int)M,
(int)N,
(int)K,
ldq,
ldk,
ldv,
lds,
ldo,
tn,
tm,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
swizzle_log,
gemm_n_iterations_aligned,
gemm_k_iterations_aligned,
gemm_sv_m_block_iterations,
batch_ndim,
alpha};
/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,
const std::vector<size_t> batch_strides = {
(size_t)batch_stride_q,
(size_t)batch_stride_k,
(size_t)batch_stride_v,
(size_t)batch_stride_o};
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -356,7 +325,24 @@ void ScaledDotProductAttention::eval_gpu(
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
size_t str_oD = 1;
size_t str_oH = o.shape(3);
size_t str_oL = o.shape(1) * str_oH;
size_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}