diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 978475c53..936127e20 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -68,6 +68,22 @@ set(STEEL_HEADERS steel/utils/type_traits.h steel/utils/integral_constant.h) +set(STEEL_ATTN_HEADERS + steel/defines.h + steel/utils.h + steel/gemm/gemm.h + steel/gemm/mma.h + steel/gemm/loader.h + steel/gemm/transforms.h + steel/utils/type_traits.h + steel/utils/integral_constant.h + steel/attn/attn.h + steel/attn/loader.h + steel/attn/mma.h + steel/attn/params.h + steel/attn/transforms.h + steel/attn/kernels/steel_attention.h) + if(NOT MLX_METAL_JIT) build_kernel(arange arange.h) build_kernel(binary binary.h binary_ops.h) @@ -93,6 +109,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) + build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS}) build_kernel(gemv_masked steel/utils.h) endif() diff --git a/mlx/backend/metal/kernels/steel/attn/attn.h b/mlx/backend/metal/kernels/steel/attn/attn.h new file mode 100644 index 000000000..991d4d69c --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/attn.h @@ -0,0 +1,296 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/attn/loader.h" +#include "mlx/backend/metal/kernels/steel/attn/mma.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h new file mode 100644 index 000000000..61d1b9563 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -0,0 +1,65 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Seqeunce + + K += tidl.z * params->K_strides[0] + // Batch + tidl.y * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + tidl.y * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D; + i += WM * WN * 32) { + int r = i / params->D; + int c = i % params->D; + + O[params->O_strides[2] * r + c] = T(0); + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal new file mode 100644 index 000000000..bbfecd7bf --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -0,0 +1,31 @@ +// Copyright © 2024 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/attn.h" +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \ + template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \ + [[kernel]] void attention( \ + const device dtype* Q [[buffer(0)]], \ + const device dtype* K [[buffer(1)]], \ + const device dtype* V [[buffer(2)]], \ + device dtype* O [[buffer(3)]],\ + const constant AttnParams* params [[buffer(4)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_attn_shapes_helper(iname, itype) \ + instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \ + + +instantiate_attn_shapes_helper(float16, half); +instantiate_attn_shapes_helper(bfloat16, bfloat16_t); + +instantiate_attn_shapes_helper(float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h new file mode 100644 index 000000000..3f084d8ec --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -0,0 +1,137 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h new file mode 100644 index 000000000..ad7a8c529 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/mma.h @@ -0,0 +1,668 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h new file mode 100644 index 000000000..d460c523d --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -0,0 +1,31 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Attn param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int L; ///< Sequence Length + int D; ///< Head Dim + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + size_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + size_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + size_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/attn/transforms.h b/mlx/backend/metal/kernels/steel/attn/transforms.h new file mode 100644 index 000000000..c0624d21b --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/transforms.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 3edda16a3..cfb57da93 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -7,6 +7,9 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" #include "mlx/utils.h" @@ -19,122 +22,57 @@ void sdpa_full_self_attention_metal( const array& q, const array& k, const array& v, - const float alpha, - array& out) { - std::ostringstream kname_self_attention; - kname_self_attention << "steel_gemm_attention_"; + const float scale, + array& o) { + using namespace mlx::steel; - constexpr const int bm = 16; - constexpr const int bn = 16; - const int bk = q.shape(-1); // already forced to be 64 or 128 + int wm = 2; + int wn = 2; - if (bk != 64 && bk != 128) { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128"); - } + int bq = 16; + int bk = 16; + int bd = 64; - constexpr const int wm = 2; - constexpr const int wn = 2; - - std::string delimiter = "_"; - - kname_self_attention << "bm_" + std::to_string(bm) + delimiter; - kname_self_attention << "bn_" + std::to_string(bn) + delimiter; - kname_self_attention << "bk_" + std::to_string(bk) + delimiter; - - for (const auto& arr : {k, v, out}) { - if (arr.dtype() != q.dtype()) { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o"); - } - } - - if (q.dtype() == float32) { - kname_self_attention << "itype" + delimiter + "float"; - } else if (q.dtype() == float16) { - kname_self_attention << "itype" + delimiter + "half"; - } else { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16."); - } + std::ostringstream kname; + kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk + << "_bd" << bd << "_wm" << wm << "_wn" << wn; auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname_self_attention.str()); + auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); - uint hidden_dim = q.shape(-1); - uint qseq = q.shape(-2); - uint qheads = q.shape(-3); + int B = q.shape(0); + int H = q.shape(1); + int L = q.shape(2); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); - const uint64_t KV_sequence_length = k.shape(-2); - const uint query_sequence_length = q.shape(-2); - const uint n_q_heads = q.shape(1); - const uint n_kv_heads = k.shape(1); + int NQ = (L + bq - 1) / bq; + int NK = (L + bk - 1) / bk; - const int M = q.shape(-2); - const int N = M; - const int K = q.shape(-1); - const size_t batch_size_out = q.shape(0) * q.shape(1); + AttnParams params{ + /* int B = */ B, + /* int H = */ H, + /* int L = */ L, + /* int D = */ D, + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, - const std::vector batch_shape = {q.shape(0) * q.shape(1)}; - const int dk = q.shape(-1); - const int ldq = dk; - const int ldk = dk; - const int ldv = dk; - const int lds = bn; - const int ldo = dk; + /* int NQ = */ NQ, + /* int NK = */ NK, - int tn = 1; - int tm = (M + bm - 1) / bm; - - const int batch_stride_q = dk * query_sequence_length; - const int batch_stride_k = dk * query_sequence_length; - const int batch_stride_v = dk * query_sequence_length; - const int batch_stride_o = dk * query_sequence_length; - const int swizzle_log = 0; - const int gemm_n_iterations_aligned = (N + bn - 1) / bn; - const int gemm_k_iterations_aligned = (K + bk - 1) / bk; - const int gemm_sv_m_block_iterations = (M + bm - 1) / bm; - const int batch_ndim = int(batch_shape.size()); - - MLXFastAttentionParams params{ - (int)M, - (int)N, - (int)K, - ldq, - ldk, - ldv, - lds, - ldo, - tn, - tm, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_o, - swizzle_log, - gemm_n_iterations_aligned, - gemm_k_iterations_aligned, - gemm_sv_m_block_iterations, - batch_ndim, - alpha}; - - const std::vector batch_strides = { - (size_t)batch_stride_q, - (size_t)batch_stride_k, - (size_t)batch_stride_v, - (size_t)batch_stride_o}; + /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); - compute_encoder.set_output_array(out, 3); - + compute_encoder.set_output_array(o, 3); compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); + MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index d3eb77d06..d27120cac 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -600,7 +600,7 @@ array scaled_dot_product_attention( * * dtype is not fp32 or fp16 */ - int threshold = 1e6; + int threshold = 1024; // TODO: Fix after dev if (memory_efficient_threshold.has_value()) { threshold = std::max(1, memory_efficient_threshold.value()); }