diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
index 58c0866be..c8f1cfcf5 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -6,18 +6,8 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////

-constant bool has_batch [[function_constant(10)]];
-
-constant bool use_out_source [[function_constant(100)]];
-constant bool do_axpby [[function_constant(110)]];
-
-constant bool align_M [[function_constant(200)]];
-constant bool align_N [[function_constant(201)]];
-constant bool align_K [[function_constant(202)]];
-
-constant bool do_gather [[function_constant(300)]];
-
-constant bool gather_bias = do_gather && use_out_source;
+constant bool align_Q [[function_constant(200)]];
+constant bool align_K [[function_constant(201)]];

 template <typename T>
 struct TransformScale {
@@ -204,7 +194,11 @@ template <
   threadgroup_barrier(mem_flags::mem_threadgroup);

   // Load Q blocks apply scale
-  loader_q.load_unsafe();
+  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
+    loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
+  } else {
+    loader_q.load_unsafe();
+  }
   loader_q.apply_inplace_op(ts);

   // Init row reduction variables
@@ -223,7 +217,11 @@ template <
   for (int kb = 0; kb < params->NK; kb++) {
     // Load K block and apply scale
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    loader_k.load_unsafe();
+    if (!align_K && kb == (params->NK_aligned)) {
+      loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+    } else {
+      loader_k.load_unsafe();
+    }

     threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -243,10 +241,36 @@ template <
       tile_matmad(Stile, Qtile, Ktile, Stile);
     }

+    // Mask out of length sequence
+    if (!align_K && kb == (params->NK_aligned)) {
+      using stile_t = decltype(Stile);
+      using selem_t = typename stile_t::elem_type;
+      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+      const short lim = params->kL - params->NK_aligned * BK;
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < stile_t::kTileRows; i++) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < stile_t::kTileCols; j++) {
+          short col_pos = sn + (j * stile_t::kFragCols);
+          STEEL_PRAGMA_UNROLL
+          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
+            if ((col_pos + jj) >= lim) {
+              Stile.frag_at(i, j)[jj] = neg_inf;
+            }
+          }
+        }
+      }
+    }
+
     simdgroup_barrier(mem_flags::mem_none);

     // Load V blocks
-    loader_v.load_unsafe();
+    if (!align_K && kb == (params->NK_aligned)) {
+      loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+    } else {
+      loader_v.load_unsafe();
+    }

     // Do softmax

@@ -309,5 +333,16 @@ template <

   // Store results
   O += (tm + sm) * params->O_strides[2] + sn;
-  Otile.template store<T, 1, 1>(O, params->O_strides[2]);
+
+  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
+    auto dst_tile_dims =
+        short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
+  } else {
+    Otile.template store<T, 1, 1>(O, params->O_strides[2]);
+  }
 }
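Note on the kernel change above: for the single partial key block (kb == NK_aligned), the new loop sets every score column at or past lim = kL - NK_aligned * BK to negative infinity before the softmax, so keys beyond the true sequence length get zero attention weight; the same block is also loaded with load_safe so out-of-bounds reads are avoided. Below is a small host-side C++ sketch of just that masking idea on a single row of scores, using plain arrays rather than the kernel's MMATile fragments; the block size, key length, and score values are made-up examples.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Mask score columns that fall past the valid key length of a partial block,
// then apply a row softmax; mirrors the -inf masking added in the kernel.
std::vector<float> masked_softmax(std::vector<float> scores, int lim) {
  constexpr float neg_inf = -std::numeric_limits<float>::infinity();
  for (int c = 0; c < static_cast<int>(scores.size()); c++) {
    if (c >= lim) {
      scores[c] = neg_inf; // out-of-length key -> zero weight after softmax
    }
  }
  float max_s = neg_inf;
  for (float s : scores) {
    max_s = std::max(max_s, s);
  }
  float denom = 0.f;
  for (float& s : scores) {
    s = std::exp(s - max_s);
    denom += s;
  }
  for (float& s : scores) {
    s /= denom;
  }
  return scores;
}

int main() {
  const int BK = 8;                     // toy block size (the kernel uses 32)
  const int kL = 13;                    // toy key length: 1 full block + 5 keys
  const int NK_aligned = kL / BK;       // 1 full key block
  const int lim = kL - NK_aligned * BK; // 5 valid columns in the partial block

  std::vector<float> scores = {0.1f, 0.7f, -0.2f, 0.3f, 0.9f, 0.5f, 0.4f, 0.2f};
  for (float p : masked_softmax(scores, lim)) {
    std::printf("%.3f ", p); // last BK - lim = 3 entries print as 0.000
  }
  std::printf("\n");
  return 0;
}
```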
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index e52b6f23a..d9e2ce2ca 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -21,11 +21,9 @@
       uint3 lid [[thread_position_in_threadgroup]]);

 #define instantiate_attn_shapes_helper(iname, itype) \
-  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
-  instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
-  instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
-  instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
-
+  instantiate_attn(iname, itype, 32, 32, 128, 4, 1) \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
+  instantiate_attn(iname, itype, 32, 32, 64, 4, 1)

 instantiate_attn_shapes_helper(float16, half);
 // instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h
index d460c523d..a9d7c7b4a 100644
--- a/mlx/backend/metal/kernels/steel/attn/params.h
+++ b/mlx/backend/metal/kernels/steel/attn/params.h
@@ -12,15 +12,20 @@ namespace steel {
 struct AttnParams {
   int B; ///< Batch Size
   int H; ///< Heads
-  int L; ///< Sequence Length
   int D; ///< Head Dim

+  int qL; ///< Query Sequence Length
+  int kL; ///< Key Sequence Length
+
   int gqa_factor; ///< Group Query factor
   float scale; ///< Attention scale

   int NQ; ///< Number of query blocks
   int NK; ///< Number of key/value blocks

+  int NQ_aligned; ///< Number of full query blocks
+  int NK_aligned; ///< Number of full key/value blocks
+
   size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
   size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
   size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
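The params change above splits the old single block count into NQ/NK (all blocks, ceiling division) and NQ_aligned/NK_aligned (fully in-bounds blocks, floor division); they differ by at most one block per dimension, and that last partial block is exactly the one the kernel routes through load_safe and the -inf masking. The following is a minimal C++ sketch of the arithmetic, matching the host-side computation in the next file; the sequence lengths 70 and 45 are arbitrary example values.

```cpp
#include <cstdio>

int main() {
  const int bq = 32, bk = 32; // query/key block sizes used by the kernel
  const int qL = 70, kL = 45; // example (made-up) sequence lengths

  // All blocks that get launched / iterated over: ceiling division.
  const int NQ = (qL + bq - 1) / bq; // 3
  const int NK = (kL + bk - 1) / bk; // 2

  // Blocks that are completely in bounds: floor division.
  const int NQ_aligned = qL / bq; // 2
  const int NK_aligned = kL / bk; // 1

  // Remaining valid rows/columns inside the single partial block; these are
  // the extents passed to load_safe() and compared against in the masking.
  const int q_tail = qL - NQ_aligned * bq; // 6
  const int k_tail = kL - NK_aligned * bk; // 13

  std::printf("NQ=%d NQ_aligned=%d q_tail=%d\n", NQ, NQ_aligned, q_tail);
  std::printf("NK=%d NK_aligned=%d k_tail=%d\n", NK, NK_aligned, k_tail);
  return 0;
}
```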
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 5aac270f8..8080c0f6f 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -33,34 +33,66 @@ void sdpa_full_self_attention_metal(
   int bk = 32;
   int bd = q.shape(-1);

-  std::ostringstream kname;
-  kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
-        << "_bd" << bd << "_wm" << wm << "_wn" << wn;
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
-  compute_encoder.set_compute_pipeline_state(kernel);
-
   int B = q.shape(0);
   int H = q.shape(1);
-  int L = q.shape(2);
   int D = q.shape(3);
   int gqa_factor = q.shape(1) / k.shape(1);

-  int NQ = (L + bq - 1) / bq;
-  int NK = (L + bk - 1) / bk;
+  int qL = q.shape(2);
+  int kL = k.shape(2);
+
+  const bool align_Q = (qL % bq) == 0;
+  const bool align_K = (kL % bk) == 0;
+
+  metal::MTLFCList func_consts = {
+      {&align_Q, MTL::DataType::DataTypeBool, 200},
+      {&align_K, MTL::DataType::DataTypeBool, 201},
+  };
+
+  std::ostringstream kname;
+  // clang-format off
+  kname << "steel_attention_"
+        << type_to_name(q)
+        << "_bq" << bq
+        << "_bk" << bk
+        << "_bd" << bd
+        << "_wm" << wm << "_wn" << wn; // clang-format on
+
+  std::string base_name = kname.str();
+
+  // clang-format off
+  kname << "_align_Q_" << (align_Q ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
+
+  std::string hash_name = kname.str();
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  const int NQ = (qL + bq - 1) / bq;
+  const int NK = (kL + bk - 1) / bk;
+
+  const int NQ_aligned = qL / bq;
+  const int NK_aligned = kL / bk;

   AttnParams params{
       /* int B = */ B,
       /* int H = */ H,
-      /* int L = */ L,
       /* int D = */ D,
+
+      /* int qL = */ qL,
+      /* int kL = */ kL,
+
       /* int gqa_factor = */ gqa_factor,
       /* float scale = */ scale,

       /* int NQ = */ NQ,
       /* int NK = */ NK,

+      /* int NQ_aligned = */ NQ_aligned,
+      /* int NK_aligned = */ NK_aligned,
+
       /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
       /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},