diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 2baa6c05b..7a0a167d7 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -120,6 +120,14 @@ if(NOT MLX_METAL_PATH) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) endif() +if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL + 26.2)) + set(MLX_ENABLE_NAX TRUE) + target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX) +else() + set(MLX_ENABLE_NAX FALSE) +endif() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) target_compile_definitions(mlx diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index fefb7cdc0..0b08607be 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -265,4 +265,15 @@ Device& device(mlx::core::Device); std::unique_ptr> new_scoped_memory_pool(); +#ifdef MLX_ENABLE_NAX + +inline bool is_nax_available() { + static bool is_nax_available_ = + /* __builtin_available(macOS 26.2, *) && */ + metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; + return is_nax_available_; +} + +#endif // MLX_ENABLE_NAX + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index c2842d534..3f04f5086 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -9,10 +9,13 @@ set(BASE_HEADERS utils.h) function(build_kernel_base TARGET SRCFILE DEPS) - set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() + if(MLX_ENABLE_NAX) + set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0) + endif() if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") set(METAL_FLAGS ${METAL_FLAGS} "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") @@ -120,6 +123,22 @@ if(NOT MLX_METAL_JIT) build_kernel(gemv_masked steel/utils.h) endif() +if(MLX_ENABLE_NAX) + + set(STEEL_NAX_HEADERS + steel/defines.h + steel/utils.h + steel/gemm/transforms.h + steel/gemm/nax.h + steel/gemm/gemm_nax.h + steel/utils/type_traits.h + steel/utils/integral_constant.h) + + build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) + +endif() + add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o diff --git a/mlx/backend/metal/kernels/steel/defines.h b/mlx/backend/metal/kernels/steel/defines.h index 6c3bfcf4e..f5657ee36 100644 --- a/mlx/backend/metal/kernels/steel/defines.h +++ b/mlx/backend/metal/kernels/steel/defines.h @@ -1,4 +1,7 @@ // Copyright © 2024 Apple Inc. +#pragma once + #define STEEL_CONST static constant constexpr const #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") +#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h new file mode 100644 index 000000000..3cd20d7b9 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" + +using namespace metal; + +namespace mlx::steel { + +template < + typename T, + short SM, + short SN, + short SK, + short BK, + bool transpose_a, + bool transpose_b, + bool kAlignedM, + bool kAlignedN, + bool kAlignedK, + short UM, + short UN, + short UK, + typename AccumType = float> +auto gemm_loop( + const device T* A, + const device T* B, + const constant GEMMParams* params [[buffer(4)]], + const short sgp_sm, + const short sgp_sn) { + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + constexpr int RA = transpose_a ? TK : TM; + constexpr int CA = transpose_a ? TM : TK; + + constexpr int RB = transpose_b ? TN : TK; + constexpr int CB = transpose_b ? TK : TN; + + using DSubTile = NAXSubTile; + using ASubTile = + NAXSubTile; + using BSubTile = + NAXSubTile; + + NAXTile Dtile; + Dtile.clear(); + + int gemm_k_iterations_ = params->gemm_k_iterations_aligned; + + STEEL_PRAGMA_NO_UNROLL + for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) { + threadgroup_barrier(mem_flags::mem_none); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + const int k = kk1; + + volatile int compiler_barrier; + + const int A_offset = transpose_a ? k * params->lda : k; + const int B_offset = transpose_b ? k : k * params->ldb; + + if constexpr (kAlignedM) { + Atile.load(A + A_offset, params->lda); + } else { + const short rmax = transpose_a ? UK : sgp_sm; + const short cmax = transpose_a ? sgp_sm : UK; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + if constexpr (kAlignedN) { + Btile.load(B + B_offset, params->ldb); + } else { + const short rmax = transpose_b ? sgp_sn : UK; + const short cmax = transpose_b ? UK : sgp_sn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + A += transpose_a ? (BK * params->lda) : BK; + B += transpose_b ? BK : (BK * params->ldb); + } + + if constexpr (!kAlignedK) { + simdgroup_barrier(mem_flags::mem_none); + + const short rem_bk = params->K - gemm_k_iterations_ * BK; + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + STEEL_PRAGMA_UNROLL + for (int mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (int nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (int kk = 0; kk < TK; kk++) { + const int m = mm * UM; + const int n = nn * UN; + const int k = kk1 + kk * UK; + const short psk = max(0, rem_bk - k); + + const int A_offset = + transpose_a ? (m + k * params->lda) : (m * params->lda + k); + const int B_offset = + transpose_b ? (k + n * params->ldb) : (k * params->ldb + n); + + { + const short psm = kAlignedM ? SM : max(0, sgp_sm - m); + const short rmax = transpose_a ? psk : psm; + const short cmax = transpose_a ? psm : psk; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + { + const short psn = kAlignedN ? SN : max(0, sgp_sn - n); + const short rmax = transpose_b ? psn : psk; + const short cmax = transpose_b ? 
psk : psn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + subtile_matmad_nax( + Dtile.subtile_at(mm, nn), + Atile.subtile_at(0, 0), + metal::bool_constant{}, + Btile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + } + } + + return Dtile; +} + +} // namespace mlx::steel diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h new file mode 100644 index 000000000..44328ed0b --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + bool kAlignedM, + bool kAlignedN, + typename NAXTile_t, + typename T> +void gemm_epilogue( + thread NAXTile_t& Dtile, + const device T* C, + const constant GEMMParams* params, + const constant GEMMAddMMParams* addmm_params, + const short sgp_sm, + const short sgp_sn) { // clang-format on + + (void)params; + + constexpr short UM = NAXTile_t::kSubTileRows; + constexpr short UN = NAXTile_t::kSubTileCols; + using CSubTile = NAXSubTile; + + using V = typename NAXTile_t::elem_type; + + constexpr short TM = NAXTile_t::kTileRows; + constexpr short TN = NAXTile_t::kTileCols; + constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + const short m = mm * UM; + const short n = nn * UN; + + CSubTile CTile; + + if constexpr (kAlignedM && kAlignedN) { + CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); + } else { + CTile.load_safe( + C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); + } + + auto delems = Dtile.subtile_at(mm, nn).elems(); + auto celems = CTile.elems(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemsPerSubTile; i++) { + if (do_axpby) { + delems[i] = addmm_params->alpha * delems[i] + + addmm_params->beta * static_cast(celems[i]); + } else { + delems[i] += static_cast(celems[i]); + } + } + } + } +} + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust 
for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + D += tm * params->ldd + tn; + + if (use_out_source) { + C += tm * addmm_params->ldc + tn * addmm_params->fdc; + } + + using DSubTile = NAXSubTile; + NAXTile Dtile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>(A, B, params, sgp_sm, sgp_sn); + if (use_out_source) { + gemm_epilogue( + Dtile, C, params, addmm_params, sgp_sm, sgp_sn); + } + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(D, int(params->ldd)); + } else { + Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm)); + } + }); + }); + }); +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal new file mode 100644 index 000000000..e6cb0b64c --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h" + +// clang-format off +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ + gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); +instantiate_gemm_shapes_helper(float32, float, float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h new file mode 100644 index 000000000..29285833a --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h @@ -0,0 +1,132 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +gather_mm_rhs_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + rhs_indices += c_row; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldd + tn; + rhs_indices += tm; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[0]; + short offset_next = 0; + int n = 0; + while (n < sgp_sm) { + n++; + offset = offset_next; + index = index_next; + offset_next = sgp_sm; + for (; n < sgp_sm; n++) { + if (rhs_indices[n] != index) { + offset_next = n; + index_next = rhs_indices[n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + using DSubTile = NAXSubTile; + NAXTile Ctile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + auto do_gemm = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>; + Ctile = do_gemm( + A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn); + + if constexpr (kAlignedN.value) { + if (offset_next - offset == SM) { + Ctile.store(C, int(params->ldd)); + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(SN, offset_next)); + } + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(sgp_sn, offset_next)); + } + }); + }); + }); + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal new file mode 100644 index 000000000..5b8589f54 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal @@ -0,0 +1,39 @@ +// Copyright © 2024 Apple Inc. 
+ +#include + +#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/utils.h" + +// clang-format off +#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm_rhs_nax, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4) +// clang-format on + +instantiate_gather_mm_shapes_helper(float16, half, float16, half); +instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); diff --git a/mlx/backend/metal/kernels/steel/gemm/nax.h b/mlx/backend/metal/kernels/steel/gemm/nax.h new file mode 100644 index 000000000..bc57a1657 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -0,0 +1,1087 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = 
dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + 
METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_t = BaseNAXFrag> +struct NAXSubTile { + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const 
thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> 
+ METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + // Load B into right operand registers + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? 
kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + // Load C into output registers (op handles accumulation) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i 
< kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? 
BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/utils/integral_constant.h b/mlx/backend/metal/kernels/steel/utils/integral_constant.h index b616acc67..526f561ee 100644 --- a/mlx/backend/metal/kernels/steel/utils/integral_constant.h +++ b/mlx/backend/metal/kernels/steel/utils/integral_constant.h @@ -74,6 +74,44 @@ integral_const_binop(>=, operator>=); integral_const_binop(&&, operator&&); integral_const_binop(||, operator||); +template >> +METAL_FUNC constexpr auto operator||(true_type, T) { + return true_type{}; +} +template >> +METAL_FUNC constexpr auto operator||(T, true_type) { + return true_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(false_type, T) { + return false_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(T, false_type) { + return false_type{}; +} + +// Dispatch utilities +template +void dispatch_bool(bool v, F f) { + if (v) { + f(true_type{}); + } else { + f(false_type{}); + } +} + +template +constexpr void const_for_loop(F f) { + if constexpr (start < stop) { + constexpr auto idx = Int{}; + f(idx); + const_for_loop(f); + } +} + #undef integral_const_binop /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index abc45575a..540fbce02 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -172,6 +172,165 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { // Regular steel matmul dispatch /////////////////////////////////////////////////////////////////////////////// +#ifdef MLX_ENABLE_NAX + +template +void steel_matmul_regular_axpby_nax( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + using namespace mlx::steel; + + // Determine dispatch kernel + int bm = 128, bn = 128, bk = 512; + int wm = 4, wn = 4; + + // Prepare kernel name + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_fused_nax_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on + + std::string base_name = kname.str(); + + const bool has_batch = (batch_shape.size() > 1); + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_fused_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); + + compute_encoder.set_compute_pipeline_state(kernel); + + // Use problem size to determine threadblock swizzle + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = tm <= 3 ? 
0 : 1; + + // Prepare steel matmul params + GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ ldd, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int64_t batch_stride_a = */ A_batch_stride, + /* const int64_t batch_stride_b = */ B_batch_stride, + /* const int64_t batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ int(batch_shape.size())}; + + // Prepare launch grid params + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(params, 4); + + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Record copies + d.add_temporaries(std::move(copies), s.index); +} + +#endif // MLX_ENABLE_NAX + template void steel_matmul_regular_axpby( const Stream& s, @@ -198,6 +357,39 @@ void steel_matmul_regular_axpby( int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { +#ifdef MLX_ENABLE_NAX + + if (metal::is_nax_available() && + (a.dtype() != float32 || env::enable_tf32())) { + return steel_matmul_regular_axpby_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* int64_t C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + +#endif // MLX_ENABLE_NAX + using namespace mlx::steel; // Determine dispatch kernel @@ -1572,6 +1764,153 @@ void gather_mm_rhs( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void gather_mm_rhs_nax( + const array& a_, + const array& b_, + const array& indices_, + array& out, + metal::Device& d, + const Stream& s) { + array indices = ensure_row_contiguous(indices_, d, s); + auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); + + // Broadcast a with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of a broadcasted + // with rhs_indices. 
We need only broadcast a and copy it as if applying the
+  // lhs_indices.
+  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
+    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
+      return ensure_row_contiguous(x, d, s);
+    }
+
+    auto x_shape = indices.shape();
+    x_shape.push_back(x.shape(-2));
+    x_shape.push_back(x.shape(-1));
+    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
+    broadcast(x, new_x);
+    return ensure_row_contiguous(new_x, d, s);
+  };
+  array a = broadcast_with_indices(a_);
+
+  // Extract the matmul shapes
+  int K = a.shape(-1);
+  int M = a.size() / K;
+  int N = b.shape(-1);
+  int lda = a.strides()[a.ndim() - 2]; // should be K
+  int E = b.shape(0);
+
+  // Define the dispatch blocks
+  int bm, bn = 128, bk = 128, wm, wn = 4;
+  if (M / E > 48) {
+    bm = 64;
+    wm = 2;
+  } else if (M / E > 24) {
+    bm = 32;
+    wm = 1;
+  } else {
+    bm = 16;
+    wm = 1;
+  }
+
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+  const bool align_K = (K % bk) == 0;
+
+  // Define the kernel name
+  std::string base_name;
+  base_name.reserve(64);
+  concatenate(
+      base_name,
+      "steel_gather_mm_rhs_nax_n",
+      transpose_b ? 't' : 'n',
+      '_',
+      type_to_name(a),
+      '_',
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+
+  metal::MTLFCList func_consts = {
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+      {&align_K, MTL::DataType::DataTypeBool, 202},
+  };
+
+  // And the kernel hash that includes the function constants
+  std::string hash_name;
+  hash_name.reserve(128);
+  concatenate(
+      hash_name,
+      base_name,
+      "_align_M_",
+      align_M ? 't' : 'n',
+      "_align_N_",
+      align_N ? 't' : 'n',
+      "_align_K_",
+      align_K ? 't' : 'n');
+
+  // Get and set the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_steel_gemm_gather_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      false,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      true);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Prepare the matmul params
+  auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
+  steel::GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ lda,
+      /* const int ldb = */ static_cast<int>(ldb),
+      /* const int ldd = */ N,
+      /* const int tiles_n = */ (N + bn - 1) / bn,
+      /* const int tiles_m = */ (M + bm - 1) / bm,
+      /* const int64_t batch_stride_a = */ 0,
+      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
+      /* const int64_t batch_stride_d = */ 0,
+      /* const int swizzle_log = */ 0,
+      /* const int gemm_k_iterations_aligned = */ (K / bk),
+      /* const int batch_ndim = */ 0};
+
+  // Prepare the grid
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
+
+  // Launch kernel
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(indices, 2);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_bytes(params, 4);
+
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+#endif // MLX_ENABLE_NAX
+
 void gather_mv(
     const array& mat_,
     const array& vec_,
@@ -1855,6 +2194,14 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We are walking a in order and b is also in order so we can batch up the
   // matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) { +#ifdef MLX_ENABLE_NAX + + if (metal::is_nax_available() && a.dtype() != float32) { + return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); + } + +#endif // MLX_ENABLE_NAX + gather_mm_rhs(a, b, rhs_indices, out, d, s); return; }
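Note on the dispatch utilities added to steel/utils/integral_constant.h: dispatch_bool maps a runtime bool onto a compile-time true_type/false_type so the align_M/align_N/align_K function constants can select fully specialized gemm_loop instantiations, and const_for_loop drives the compile-time subtile loops used by store_slice. Below is a minimal host-side C++17 sketch of the same pattern, for illustration only; it substitutes std::true_type/std::integral_constant for the Metal-side true_type/false_type and Int<>, and run_tile/main are hypothetical stand-ins, not part of the diff.

#include <cstdio>
#include <type_traits>

// Runtime bool -> compile-time bool_constant (mirrors steel::dispatch_bool).
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

// Compile-time counted loop (mirrors steel::const_for_loop).
template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    f(std::integral_constant<int, start>{});
    const_for_loop<start + step, stop, step>(f);
  }
}

// Hypothetical stand-in for a kernel body specialized on alignment flags.
template <bool kAlignedM, bool kAlignedN>
void run_tile() {
  std::printf("kAlignedM=%d kAlignedN=%d\n", int(kAlignedM), int(kAlignedN));
}

int main() {
  bool align_M = true;   // e.g. (M % BM) == 0, known only at runtime
  bool align_N = false;  // e.g. (N % BN) == 0
  dispatch_bool(align_M, [&](auto aM) {
    dispatch_bool(align_N, [&](auto aN) {
      // aM/aN carry the flags as types, so ::value is usable as a template argument.
      run_tile<decltype(aM)::value, decltype(aN)::value>();
    });
  });
  const_for_loop<0, 4, 1>([](auto i) {
    // i is an integral_constant, usable wherever a constant expression is required.
    std::printf("subtile row %d\n", int(i));
  });
  return 0;
}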