diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py index 0e81b6dfa..23383475e 100644 --- a/benchmarks/python/sdpa_bench.py +++ b/benchmarks/python/sdpa_bench.py @@ -1,62 +1,189 @@ +# Copyright © 2024 Apple Inc. + import argparse import math +import os +import subprocess +import time import mlx.core as mx -from time_utils import time_fn +import numpy as np -MAX_SEQ = 300 -START_SEQ = 100 -SEQ_INCREMENT = 50 +device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) +device_name = device_name.decode("utf-8").strip("\n") + +N_warmup = 5 +N_iter_bench = 40 +N_iter_func = 8 -def time_self_attention_primitives(): - mx.random.seed(3) - B = 2 - H = 38 - D = 64 - for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT): - q = mx.random.uniform(shape=(B, H, R, D)) - k = mx.random.uniform(shape=(B, H, R, D)) - v = mx.random.uniform(shape=(B, H, R, D)) - scale = 1.0 / math.sqrt(float(D)) - mx.eval(q, k, v) +def bench(f, *args): + for i in range(N_warmup): + f(*args) - def sdpa_primitives(qs, ks, vs, alpha): - s = (alpha * qs) @ ks.transpose(0, 1, 3, 2) - p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) - o = p @ vs - return o - - time_fn(sdpa_primitives, q, k, v, scale) + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(*args) + e = time.perf_counter_ns() + return (e - s) * 1e-9 -def time_self_attention_sdpa(): - mx.random.seed(3) - B = 2 - H = 38 - D = 64 - for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT): - q = mx.random.uniform(shape=(B, H, R, D)) - k = mx.random.uniform(shape=(B, H, R, D)) - v = mx.random.uniform(shape=(B, H, R, D)) - scale = 1.0 / math.sqrt(float(D)) - mx.eval(q, k, v) +def mlx_sdpa_fused_inner(q, k, v, scale): + return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) - def sdpa_fused(qs, ks, vs, alpha): - o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha) - return o - time_fn(sdpa_fused, q, k, v, scale) +def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): + q_dtype = q.dtype + q = q * mx.array(scale, q_dtype) + n_q_heads = q.shape[-3] + n_kv_heads = k.shape[-3] + n_repeats = n_q_heads // n_kv_heads + + B = q.shape[0] + L = q.shape[2] + + if n_repeats > 1: + q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) + k = mx.expand_dims(k, 2) + v = mx.expand_dims(v, 2) + + scores = q @ mx.swapaxes(k, -1, -2) + if f32softmax: + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype) + else: + scores = mx.softmax(scores, axis=-1) + + out = scores @ v + if n_repeats > 1: + out = mx.reshape(out, [B, n_q_heads, L, -1]) + + return out + + +def mlx_spda_unfused(q, k, v, scale, transpose): + q_out = q + if transpose: + k = mx.transpose(k, (0, 2, 1, 3)) + v = mx.transpose(v, (0, 2, 1, 3)) + + for i in range(N_iter_func): + if transpose: + q_out = mx.transpose(q_out, (0, 2, 1, 3)) + q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale) + if transpose: + q_out = mx.transpose(q_out, (0, 2, 1, 3)) + + mx.eval(q_out) + return q_out + + +def mlx_spda_fused(q, k, v, scale, transpose): + q_out = q + if transpose: + k = mx.transpose(k, (0, 2, 1, 3)) + v = mx.transpose(v, (0, 2, 1, 3)) + + for i in range(N_iter_func): + if transpose: + q_out = mx.transpose(q_out, (0, 2, 1, 3)) + q_out = mlx_sdpa_fused_inner(q_out, k, v, scale) + if transpose: + q_out = mx.transpose(q_out, (0, 2, 1, 3)) + + mx.eval(q_out) + return q_out + + +def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True): + shape_q = ( + (B, qsl, n_q_heads, head_dim) if transpose 
else (B, n_q_heads, qsl, head_dim) + ) + shape_kv = ( + (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim) + ) + + q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) + v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) + + scale = math.sqrt(1.0 / head_dim) + + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose) + time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose) + + if transpose: + q_mx = mx.transpose(q_mx, (0, 2, 1, 3)) + k_mx = mx.transpose(k_mx, (0, 2, 1, 3)) + v_mx = mx.transpose(v_mx, (0, 2, 1, 3)) + + o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale) + o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True) + + atol = 1e-5 if np_dtype == np.float32 else 1e-4 + + if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol): + print( + f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" + ) + + return time_mlx_fused, time_mlx_unfused + + +def get_gflop_count(B, M, N, K): + return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) if __name__ == "__main__": - parser = argparse.ArgumentParser("MLX benchmarks.") - parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") - args = parser.parse_args() - if args.gpu: - mx.set_default_device(mx.gpu) - else: - mx.set_default_device(mx.cpu) + parser = argparse.ArgumentParser(description="Run gemm benchmarks") - time_self_attention_sdpa() - time_self_attention_primitives() + dtypes = ("float16", "float32")[:1] + transposes = (False,) + + # fmt: off + shapes_64 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 32, 32, 64, 32, 32), + ( 1, 64, 64, 64, 32, 32), + ( 1, 128, 128, 64, 32, 32), + ( 1, 256, 256, 64, 32, 32), + ( 1, 512, 512, 64, 32, 32), + ( 1, 1024, 1024, 64, 32, 32), + ( 1, 2048, 2048, 64, 32, 32), + ( 1, 4096, 4096, 64, 32, 32), + ) + + shapes_80 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 1024, 1024, 80, 32, 32), + ( 1, 2048, 2048, 80, 32, 32), + ( 1, 4096, 4096, 80, 32, 32), + ) + + shapes_128 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 1024, 1024, 128, 32, 32), + ( 1, 2048, 2048, 128, 32, 32), + ( 1, 4096, 4096, 128, 32, 32), + ) + # fmt: on + + shapes = shapes_64 + shapes_80 + shapes_128 + + print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") + + for dtype in dtypes: + for transpose in transposes: + for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: + np_dtype = getattr(np, dtype) + time_mlx_fused, time_mlx_unfused = bench_shape( + B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose + ) + diff = time_mlx_unfused / time_mlx_fused - 1.0 + t_str = 1 if transpose else 0 + print( + f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. 
* diff:+5.2f}%" + ) diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 978475c53..9c0856425 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -44,9 +44,7 @@ build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) build_kernel(rope) -build_kernel( - scaled_dot_product_attention scaled_dot_product_attention_params.h - sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h) +build_kernel(scaled_dot_product_attention sdpa_vector.h) set(STEEL_HEADERS steel/defines.h @@ -68,6 +66,24 @@ set(STEEL_HEADERS steel/utils/type_traits.h steel/utils/integral_constant.h) +set(STEEL_ATTN_HEADERS + steel/defines.h + steel/utils.h + steel/gemm/gemm.h + steel/gemm/mma.h + steel/gemm/loader.h + steel/gemm/transforms.h + steel/utils/type_traits.h + steel/utils/integral_constant.h + steel/attn/attn.h + steel/attn/loader.h + steel/attn/mma.h + steel/attn/params.h + steel/attn/transforms.h + steel/attn/kernels/steel_attention.h) + +build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS}) + if(NOT MLX_METAL_JIT) build_kernel(arange arange.h) build_kernel(binary binary.h binary_ops.h) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 014231285..b5bc9607e 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -1,930 +1,11 @@ #include -#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" #include "mlx/backend/metal/kernels/sdpa_vector.h" -#include "mlx/backend/metal/kernels/steel/defines.h" -#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/utils.h" using namespace metal; -using namespace mlx::steel; - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short alignment = 1, - short n_reads = (BCOLS * BROWS) / (tgp_size), - short TCOLS = BCOLS / n_reads, - short TROWS = tgp_size / TCOLS> -struct BlockLoaderFA { - STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; - STEEL_CONST short vec_size = n_reads; - - // Leading dimension for src - const int src_ld; - const int tile_stride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - struct alignas(alignment * sizeof(T)) ReadVector { - uint8_t v[sizeof(T) * vec_size]; - }; - - /* Constructor */ - METAL_FUNC BlockLoaderFA( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - *((threadgroup ReadVector*)(&dst[i * dst_ld])) = - *((const device ReadVector*)(&src[i * src_ld])); - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = src_tile_dim - short2(bj, bi); - - // Skip loading if thread has no valid reads - if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - return; - } - - // Use fast thread memory for bound checks - bool tmp_idx[vec_size]; - T tmp_val[vec_size]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - // Make sure tmp_idx only contains valid indices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); - } - - // Read valid indices into tmp_val - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; - } - - // Zero out uneeded values - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = tmp_val[j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tile_stride; - } - METAL_FUNC void next(short n) { - src += n * tile_stride; - } -}; - -template -struct LoopAlignment {}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMAFA { - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; - - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - // Offsets within threadgroup - const short tm; - const short tn; - - short sm; - short sn; - - ushort sid; - ushort slid; - - short As_offset; - short Bs_offset; - - /* Constructor */ - METAL_FUNC BlockMMAFA( - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - slid = simd_lane_id; - sid = simd_group_id; - - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; - - // Iterate over BK in blocks of 8 - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup A as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } - - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup B as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } - - simdgroup_barrier(mem_flags::mem_none); - - // Multiply and accumulate into result simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? 
(TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; - } - } - - METAL_FUNC void rescale_output(const threadgroup float* Corrections) { - // Loop over all simdgroup tiles - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - short row = sm + tm + i * TM_stride; - float scale_value = Corrections[row]; - - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - // int offset = (i * TM_stride) * ldc + (j * TN_stride); - accum[0] *= scale_value; - accum[1] *= scale_value; - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* C, const int ldc) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + tn + sn; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; - - // Write out C - C[offset] = outs[0]; - C[offset + 1] = outs[1]; - } - } - } - - METAL_FUNC void store_result_to_tgp_memory( - threadgroup U* C, - const int ldc, - short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } - } - - METAL_FUNC void - store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 
TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; - } - } - } - - METAL_FUNC void store_result_safe( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); - } - } - } - } - } - - METAL_FUNC void clear_results() { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - results[i * TN + j] = simdgroup_matrix(0); - } - } - } -}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct FastAttentionKernel { - STEEL_CONST short tgp_padding = 16 / sizeof(T); - STEEL_CONST short float_padding = 16 / sizeof(float); - STEEL_CONST short tgp_mem_size_q = - transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_k = - transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_v = - transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); - - // maxes, rowsums, rescale - STEEL_CONST short tgp_mem_size_corrections = - 4 * (BM * sizeof(float) + float_padding); - - STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; - - STEEL_CONST short tgp_mem_size = share_kv_smem - ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections - : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections + tgp_mem_size_v; - - STEEL_CONST short tgp_size = WM * WN * 32; - - static_assert(transpose_q == false, "Expected Q not transposed."); - static_assert(transpose_k == true, "Expected K transposed."); - static_assert(transpose_v == false, "Expected V not transposed."); - static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); - - using loader_q_t = BlockLoaderFA< - T, - transpose_q ? BK : BM, - transpose_q ? BM : BK, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - !transpose_q, - tgp_size>; - - using loader_k_t = BlockLoaderFA< - T, - transpose_k ? 
BN : BK, - transpose_k ? BK : BN, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - transpose_k, - tgp_size>; - - using loader_v_t = BlockLoaderFA< - T, - transpose_v ? BK : BN, - transpose_v ? BN : BK, - transpose_v ? BN + tgp_padding : BK + tgp_padding, - transpose_v, - tgp_size>; - - using mma_qk_t = BlockMMAFA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - AccumType, - Epilogue>; - - using mma_sv_t = BlockMMAFA< - T, - U, - BM, - BK, - BN, - WM, - WN, - false, - transpose_v, - BN + tgp_padding, - BK + tgp_padding, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_k_t& loader_b, - thread mma_qk_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - (void)tgp_bm; - - short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - // not valid for gemm_k_iterations > 1 (so, BK == d_k) - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - } - - static METAL_FUNC void initialize_corrections( - threadgroup float* C, - uint simd_lane_id, - uint simd_group_id) { - if (simd_group_id == 0) { - threadgroup float* maxes = C; - threadgroup float* sums = C + (BM + float_padding); - threadgroup float* o_rescale = sums + (BM + float_padding); - threadgroup float* output_rescale = o_rescale + (BM + float_padding); - - if (simd_lane_id < BM) { - maxes[simd_lane_id] = -INFINITY; // m_i - sums[simd_lane_id] = 0.f; // l_i - o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) - output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i - } - } - } - - static METAL_FUNC void rescale_ss( - threadgroup T* Ss, - threadgroup float* Corrections, - uint simd_group_id, - uint simd_lane_id, - short2 local_blocks, - float alpha) { - if (simd_group_id == 0) { - short row_offset = BM + float_padding; - threadgroup float* maxes = Corrections; - threadgroup float* sums = Corrections + row_offset; - threadgroup float* o_rescale = sums + row_offset; - threadgroup float* output_scales = o_rescale + row_offset; - - if (simd_lane_id < uint(local_blocks.y)) { - float m_i_old = maxes[simd_lane_id]; - float l_i_old = sums[simd_lane_id]; - - float m_i_new = m_i_old; - float l_i_new = l_i_old; - - short offset = simd_lane_id * (BN + tgp_padding); - - float m_ij = -INFINITY; - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - m_ij = max(m_ij, val); - } - - m_i_new = max(m_ij, m_i_new); - - float rowsum = 0.f; // lij - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - float P_i_j = exp(val - m_ij); - rowsum += P_i_j; - P_i_j = P_i_j * exp(m_ij - m_i_new); - Ss[offset + j] = T(P_i_j); - } - - l_i_new = - exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; - maxes[simd_lane_id] = m_i_new; - sums[simd_lane_id] = l_i_new; - float rescale = l_i_old * exp(m_i_old - m_i_new); - o_rescale[simd_lane_id] = rescale; - output_scales[simd_lane_id] = 1.0 / l_i_new; - } - } - } 
- - /* Main kernel function */ - static METAL_FUNC void run( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device U* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - threadgroup T* Qs [[threadgroup(0)]], - threadgroup T* Ks [[threadgroup(1)]], - threadgroup T* Ss [[threadgroup(2)]], - threadgroup T* Vs [[threadgroup(3)]], - threadgroup float* Corrections [[threadgroup(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in Q, O; and head in K, V. - const int c_row = tid_y * BM; - - Q += transpose_q ? c_row : c_row * params->ldq; - thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); - - short tgp_bm = min(BM, params->M - c_row); - short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - - loader_q.load_safe(tile_dims_Q); - - initialize_corrections(Corrections, simd_lane_id, simd_group_id); - - O += c_row * params->ldo; - - // Prepare threadgroup mma operation - thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); - thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); - thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); - thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); - - for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; - n_block++) { - short c_col = BN; - - // Prepare threadgroup loading operations - short gemm_k_iterations = params->gemm_k_iterations_aligned; - short tgp_bn_qk = min(BN, params->N - c_col * n_block); - threadgroup_barrier(mem_flags::mem_none); - - /////////////////////////////////////////////////////////////////////////////// - { // Loop over K - unaligned case - - if (tgp_bm == BM && tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } else if (tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else if (tgp_bm == BM) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } - } - - mma_qk_op.store_result_to_tgp_memory( - Ss, BN + tgp_padding, short2(BN, BM)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - rescale_ss( - Ss, - Corrections, - simd_group_id, - simd_lane_id, - short2(tgp_bn_qk, tgp_bm), - params->alpha); - - loader_v.load_safe(short2(BK, tgp_bn_qk)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); - mma_softmax_sv_op.rescale_output(o_scales); - - mma_softmax_sv_op.mma(Ss, Vs); - - threadgroup float* final_output_scales = - Corrections + 3 * (BM + float_padding); - - mma_softmax_sv_op.rescale_output(final_output_scales); - - loader_v.next(); - loader_k.next(BN); - - mma_qk_op.clear_results(); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - 
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); - } -}; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using attention_kernel = FastAttentionKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_v, - MN_aligned, - K_aligned>; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* Q_bstrides = batch_strides; - const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); - - Q += batch_offsets.x; - K += batch_offsets.y; - V += batch_offsets.y; - - } else { - Q += params->batch_stride_q * tid.z; - K += params->batch_stride_k * tid.z; - V += params->batch_stride_v * tid.z; - } - - // same shape as input - O += params->batch_stride_o * tid.z; - threadgroup T Qs[attention_kernel::tgp_mem_size_q]; - threadgroup T Ss[attention_kernel::tgp_mem_size_s]; - threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; - - if (attention_kernel::share_kv_smem) { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } else { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T Vs[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } -} - // clang-format off - -// SDPA full instantiations -#define instantiate_fast_inference_self_attention_kernel( \ - itype, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ - "_itype_" #itype)]] [[kernel]] void \ - attention( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - device otype* O [[buffer(3)]], \ - const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 64, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 128, - 2, - 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); - // SDPA vector instantiations #define 
instantiate_sdpa_vector(type, head_dim) \ instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \ diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h deleted file mode 100644 index a77dad268..000000000 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h +++ /dev/null @@ -1,42 +0,0 @@ -// -// scaled_dot_product_attention_params.h -// mlx - -#pragma once - -struct MLXFastAttentionParams { - const int M; - const int N; - const int K; - - const int ldq; // ldq == ldo - const int ldk; - const int ldv; - const int lds; - const int ldo; - - const int tiles_n; - const int tiles_m; - - const int batch_stride_q; - const int batch_stride_k; - const int batch_stride_v; - const int batch_stride_o; - - const int swizzle_log; - const int gemm_n_iterations_aligned; - const int gemm_k_iterations_aligned; - const int gemm_sv_m_block_iterations; - - const int batch_ndim; - const float alpha; -}; - -struct MLXScaledDotProductAttentionParams { - // Associated dimensions & transposition information - const uint QUERY_SEQUENCE_LENGTH = 1; - const uint N_Q_HEADS = 32; - const uint N_KV_HEADS = 32; - const uint KV_TILES = 1; - const float INV_ALPHA = 0.08838834764831843f; -}; diff --git a/mlx/backend/metal/kernels/steel/attn/attn.h b/mlx/backend/metal/kernels/steel/attn/attn.h new file mode 100644 index 000000000..991d4d69c --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/attn.h @@ -0,0 +1,296 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/attn/loader.h" +#include "mlx/backend/metal/kernels/steel/attn/mma.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? 
BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h new file mode 100644 index 000000000..c5c69c30b --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -0,0 +1,349 @@ +// Copyright © 2024 Apple Inc. 
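+//
+// Fused scaled-dot-product-attention kernel. Each threadgroup owns a BQ-row
+// block of queries and streams over the K/V sequence in BK-row blocks,
+// computing the softmax online: a per-row running max and running sum are
+// maintained, and the partial output tile is rescaled by exp(m_old - m_new)
+// before the next S @ V accumulation, so the full attention matrix is never
+// materialized.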
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Seqeunce + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + // Prepare threadgroup memory + constexpr short padQ = 0; // 16 / sizeof(T); + constexpr short padK = 0; // 16 / sizeof(T); + constexpr short padV = 0; // 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + threadgroup T Qs[BQ * (BD + padQ)]; + threadgroup T Ks[(BK + padK) * BD]; + threadgroup T Vs[BK * (BD + padV)]; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short 
reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::min; + } + + // Loop over KV seq length + for (int kb = 0; kb < params->NK; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + } else { + loader_k.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do S = Q @ K.T + Stile.clear(); + + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out of length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + const short lim = params->kL - params->NK_aligned * BK; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= lim) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load V 
blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + } else { + loader_v.load_unsafe(); + } + + // Do softmax + + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Stile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); + + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } + + // Update O + Otile.template row_bin_op(factor); + + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); + Vtile.template load(&Vs[Vs_offset]); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = S @ V + tile_matmad(Otile, Stile, Vtile, Otile); + + // Prepare for next iteration + loader_k.next(); + loader_v.next(); + } + + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = + short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); + } else { + Otile.template store(O, params->O_strides[2]); + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal new file mode 100644 index 000000000..0d05a6932 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -0,0 +1,31 @@ +// Copyright © 2024 Apple Inc. 
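+//
+// Host-name instantiations of the fused attention kernel for the supported
+// block / head-dim configurations: (BQ, BK, BD) = (32, 16, 128), (32, 32, 80)
+// and (32, 32, 64), each in float16, bfloat16 and float32.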
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/attn.h" +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \ + template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \ + [[kernel]] void attention( \ + const device dtype* Q [[buffer(0)]], \ + const device dtype* K [[buffer(1)]], \ + const device dtype* V [[buffer(2)]], \ + device dtype* O [[buffer(3)]],\ + const constant AttnParams* params [[buffer(4)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_attn_shapes_helper(iname, itype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1) + +instantiate_attn_shapes_helper(float16, half); +instantiate_attn_shapes_helper(bfloat16, bfloat16_t); + +instantiate_attn_shapes_helper(float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h new file mode 100644 index 000000000..2849c00f1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
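+//
+// Threadgroup tile loaders used by the attention kernel. BlockLoader keeps a
+// row-major destination layout; BlockLoaderT additionally takes explicit
+// destination row/column strides so a tile (e.g. K) can be written to
+// threadgroup memory transposed.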
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; + +template < + typename T, + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h new file mode 100644 index 000000000..5ddd37ac6 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/mma.h @@ -0,0 +1,726 @@ +// Copyright © 2024 Apple Inc. 
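+//
+// Simdgroup MMA helpers used by the steel attention kernel: BaseMMAFrag
+// wraps 8x8 simdgroup_matrix fragments (load/store, mma, row reductions),
+// MMATile groups those fragments into a per-thread tile, and BlockMMA
+// drives the (BM, BK) x (BK, BN) threadgroup-level multiply-accumulate.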
+ +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < 
lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j 
< kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? 
(N - 1 - n) : n; + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + 
Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short 
kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h new file mode 100644 index 000000000..a9d7c7b4a --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -0,0 +1,36 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Attn param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + size_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + size_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + size_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/transforms.h b/mlx/backend/metal/kernels/steel/attn/transforms.h new file mode 100644 index 000000000..c0624d21b --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/transforms.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. 
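+//
+// Epilogue transforms used by the steel attention MMA helpers: TransformNone,
+// TransformAdd, and TransformAxpby map accumulator values to the output type,
+// AccumHelper selects the accumulation type, and BlockSwizzle remaps
+// threadgroup ids for swizzled launches.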
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index 7ad6c0fe6..aea235abb 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -385,9 +385,9 @@ struct BlockMMA { STEEL_CONST short TN_stride = kFragSize * WN; // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; + STEEL_CONST short TM = BM / (kFragSize * WM); // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; + STEEL_CONST short TN = BN / (kFragSize * WN); // Threadgroup A strides STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 3edda16a3..079a0baff 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -6,7 +6,9 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" + +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" #include "mlx/utils.h" @@ -19,122 +21,89 @@ void sdpa_full_self_attention_metal( const array& q, const array& k, const array& v, - const float alpha, - array& out) { - std::ostringstream kname_self_attention; - kname_self_attention << "steel_gemm_attention_"; + const float scale, + array& o) { + using namespace mlx::steel; - constexpr const int bm = 16; - constexpr const int bn = 16; - const int bk = q.shape(-1); // already forced to be 64 or 128 + int wm = 4; + int wn = 1; - if (bk != 64 && bk != 128) { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128"); - } + int bd = q.shape(-1); + int bq = 32; + int bk = bd < 128 ? 
32 : 16; - constexpr const int wm = 2; - constexpr const int wn = 2; + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); - std::string delimiter = "_"; + int qL = q.shape(2); + int kL = k.shape(2); - kname_self_attention << "bm_" + std::to_string(bm) + delimiter; - kname_self_attention << "bn_" + std::to_string(bn) + delimiter; - kname_self_attention << "bk_" + std::to_string(bk) + delimiter; + const bool align_Q = (qL % bq) == 0; + const bool align_K = (kL % bk) == 0; - for (const auto& arr : {k, v, out}) { - if (arr.dtype() != q.dtype()) { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o"); - } - } + metal::MTLFCList func_consts = { + {&align_Q, MTL::DataType::DataTypeBool, 200}, + {&align_K, MTL::DataType::DataTypeBool, 201}, + }; - if (q.dtype() == float32) { - kname_self_attention << "itype" + delimiter + "float"; - } else if (q.dtype() == float16) { - kname_self_attention << "itype" + delimiter + "half"; - } else { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16."); - } + std::ostringstream kname; + // clang-format off + kname << "steel_attention_" + << type_to_name(q) + << "_bq" << bq + << "_bk" << bk + << "_bd" << bd + << "_wm" << wm << "_wn" << wn; // clang-format on + + std::string base_name = kname.str(); + + // clang-format off + kname << "_align_Q_" << (align_Q ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname_self_attention.str()); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); - uint hidden_dim = q.shape(-1); - uint qseq = q.shape(-2); - uint qheads = q.shape(-3); + const int NQ = (qL + bq - 1) / bq; + const int NK = (kL + bk - 1) / bk; - const uint64_t KV_sequence_length = k.shape(-2); - const uint query_sequence_length = q.shape(-2); - const uint n_q_heads = q.shape(1); - const uint n_kv_heads = k.shape(1); + const int NQ_aligned = qL / bq; + const int NK_aligned = kL / bk; - const int M = q.shape(-2); - const int N = M; - const int K = q.shape(-1); - const size_t batch_size_out = q.shape(0) * q.shape(1); + AttnParams params{ + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, - const std::vector batch_shape = {q.shape(0) * q.shape(1)}; - const int dk = q.shape(-1); - const int ldq = dk; - const int ldk = dk; - const int ldv = dk; - const int lds = bn; - const int ldo = dk; + /* int qL = */ qL, + /* int kL = */ kL, - int tn = 1; - int tm = (M + bm - 1) / bm; + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, - const int batch_stride_q = dk * query_sequence_length; - const int batch_stride_k = dk * query_sequence_length; - const int batch_stride_v = dk * query_sequence_length; - const int batch_stride_o = dk * query_sequence_length; - const int swizzle_log = 0; - const int gemm_n_iterations_aligned = (N + bn - 1) / bn; - const int gemm_k_iterations_aligned = (K + bk - 1) / bk; - const int gemm_sv_m_block_iterations = (M + bm - 1) / bm; - const int batch_ndim = int(batch_shape.size()); + /* int NQ = */ NQ, + /* int NK = */ NK, - MLXFastAttentionParams params{ - (int)M, - (int)N, - (int)K, - ldq, - ldk, - ldv, - lds, - ldo, - tn, - tm, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_o, - swizzle_log, 
- gemm_n_iterations_aligned, - gemm_k_iterations_aligned, - gemm_sv_m_block_iterations, - batch_ndim, - alpha}; + /* int NQ_aligned = */ NQ_aligned, + /* int NK_aligned = */ NK_aligned, - const std::vector batch_strides = { - (size_t)batch_stride_q, - (size_t)batch_stride_k, - (size_t)batch_stride_v, - (size_t)batch_stride_o}; + /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); - compute_encoder.set_output_array(out, 3); - + compute_encoder.set_output_array(o, 3); compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); + MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -356,7 +325,24 @@ void ScaledDotProductAttention::eval_gpu( const auto& q = copy_unless(is_matrix_contiguous, q_pre); const auto& k = copy_unless(is_matrix_contiguous, k_pre); const auto& v = copy_unless(is_matrix_contiguous, v_pre); - o.set_data(allocator::malloc_or_wait(o.nbytes())); + + size_t str_oD = 1; + size_t str_oH = o.shape(3); + size_t str_oL = o.shape(1) * str_oH; + size_t str_oB = o.shape(2) * str_oL; + size_t data_size = o.shape(0) * str_oB; + + array::Flags flags{ + /* bool contiguous = */ 1, + /* bool row_contiguous = */ 0, + /* bool col_contiguous = */ 0, + }; + + o.set_data( + allocator::malloc_or_wait(o.nbytes()), + data_size, + {str_oB, str_oH, str_oL, str_oD}, + flags); sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ccce1d6b1..731912d69 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -600,7 +600,7 @@ array scaled_dot_product_attention( * * dtype is not fp32 or fp16 */ - int threshold = 1e6; + int threshold = 32; // TODO: Fix after dev if (memory_efficient_threshold.has_value()) { threshold = std::max(1, memory_efficient_threshold.value()); } @@ -644,11 +644,10 @@ array scaled_dot_product_attention( const bool sdpa_vector_supported_head_dim = query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128; const bool sdpa_full_supported_head_dim = - query_head_dim == 64 || query_head_dim == 128; + query_head_dim == 64 || query_head_dim == 80; const bool supports_sdpa_full = query_sequence_length >= threshold && !mask.has_value() && sdpa_full_supported_head_dim && - n_q_heads == n_kv_heads && final_type != bfloat16 && stream.device == Device::gpu; const bool supports_sdpa_vector = query_sequence_length == 1 &&