From 0dbc7e5bee44c0a32264d931a7a6e13ec7392ab6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 21 Nov 2025 13:28:15 -0800 Subject: [PATCH] Centralize NAX condition (#2811) --- mlx/backend/metal/CMakeLists.txt | 8 -- mlx/backend/metal/device.h | 17 ++- mlx/backend/metal/kernels/CMakeLists.txt | 16 ++- mlx/backend/metal/matmul.cpp | 117 ++++++--------- mlx/backend/metal/quantized.cpp | 134 +++++++----------- .../metal/scaled_dot_product_attention.cpp | 34 ++--- 6 files changed, 135 insertions(+), 191 deletions(-) diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 111975709..3cfd0b22b 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) endif() -if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL - 26.2)) - set(MLX_ENABLE_NAX TRUE) - target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX) -else() - set(MLX_ENABLE_NAX FALSE) -endif() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) target_compile_definitions(mlx diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 746fbc088..564d15a9b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -265,14 +265,19 @@ Device& device(mlx::core::Device); std::unique_ptr> new_scoped_memory_pool(); -#ifdef MLX_ENABLE_NAX - inline bool is_nax_available() { - static bool is_nax_available_ = - metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; + auto _check_nax = []() { + bool can_use_nax = false; + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + can_use_nax = true; + } + can_use_nax &= + metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; + return can_use_nax; + }; + static bool is_nax_available_ = _check_nax(); return is_nax_available_; } -#endif // MLX_ENABLE_NAX - } // namespace mlx::core::metal diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 5215fb346..bfbbfd799 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -9,13 +9,17 @@ set(BASE_HEADERS utils.h) function(build_kernel_base TARGET SRCFILE DEPS) - set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + set(METAL_FLAGS + -x + metal + -Wall + -Wextra + -fno-fast-math + -Wno-c++17-extensions + -Wno-c++20-extensions) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() - if(MLX_ENABLE_NAX) - set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0) - endif() if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") set(METAL_FLAGS ${METAL_FLAGS} "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") @@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT) build_kernel(gemv_masked steel/utils.h) endif() -if(MLX_ENABLE_NAX) - +if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL + 26.2)) set(STEEL_NAX_HEADERS steel/defines.h steel/utils.h diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index e4f625383..add11c146 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { // Regular steel matmul dispatch /////////////////////////////////////////////////////////////////////////////// -#ifdef MLX_ENABLE_NAX - template void steel_matmul_regular_axpby_nax( const Stream& s, @@ -210,11 +208,11 @@ void steel_matmul_regular_axpby_nax( std::ostringstream kname; // clang-format off - kname << "steel_gemm_fused_nax_" + kname << "steel_gemm_fused_nax_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') - << "_" << type_to_name(a) - << "_" << type_to_name(out) + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on @@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax( d.add_temporaries(std::move(copies), s.index); } -#endif // MLX_ENABLE_NAX - template void steel_matmul_regular_axpby( const Stream& s, @@ -357,41 +353,35 @@ void steel_matmul_regular_axpby( int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { -#ifdef MLX_ENABLE_NAX - - if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && - (env::enable_tf32() || a.dtype() != float32)) { - return steel_matmul_regular_axpby_nax( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& a = */ a, - /* const array& b = */ b, - /* const array& c = */ c, - /* array& out = */ out, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* int batch_size_out = */ batch_size_out, - /* int lda = */ lda, - /* int ldb = */ ldb, - /* int ldd = */ ldd, - /* bool transpose_a = */ transpose_a, - /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides batch_strides = */ batch_strides, - /* int64_t A_batch_stride = */ A_batch_stride, - /* int64_t B_batch_stride = */ B_batch_stride, - /* int64_t matrix_stride_out = */ matrix_stride_out, - /* int64_t C_batch_stride = */ C_batch_stride, - /* float alpha = */ alpha, - /* float beta = */ beta); - } + if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && + (env::enable_tf32() || a.dtype() != float32)) { + return steel_matmul_regular_axpby_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* int64_t C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha, + /* float beta = */ beta); } -#endif // MLX_ENABLE_NAX - using namespace mlx::steel; // Determine dispatch kernel @@ -405,11 +395,11 @@ void steel_matmul_regular_axpby( std::ostringstream kname; // clang-format off - kname << "steel_gemm_fused_" + kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') - << "_" << type_to_name(a) - << "_" << type_to_name(out) + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; // clang-format on @@ -574,14 +564,14 @@ void steel_gemm_splitk_axpby( std::ostringstream kname; // clang-format off - kname << "steel_gemm_splitk_" + kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') - << "_" << type_to_name(a) - << "_" << type_to_name(C_split) + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn - << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on // Encode and dispatch gemm kernel @@ -915,10 +905,10 @@ void gemv_axbpy( const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); // clang-format off - kname << "_bm" << bm << "_bn" << bn - << "_sm" << sm << "_sn" << sn + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn - << "_nc" << !contiguous_kernel + << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; // clang-format on // Encode and dispatch kernel @@ -1766,8 +1756,6 @@ void gather_mm_rhs( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#ifdef MLX_ENABLE_NAX - void gather_mm_rhs_nax( const array& a_, const array& b_, @@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#endif // MLX_ENABLE_NAX - void gather_mv( const array& mat_, const array& vec_, @@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // We are walking a in order and b is also in order so we can batch up the // matmuls and reuse reading a and b. if (M == 1 && right_sorted_ == true) { -#ifdef MLX_ENABLE_NAX - - if (__builtin_available( - macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && - !issubdtype(a.dtype(), complexfloating) && - (env::enable_tf32() || a.dtype() != float32)) { - return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); - } + if (metal::is_nax_available() && + (env::enable_tf32() || a.dtype() != float32)) { + return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); } - -#endif // MLX_ENABLE_NAX - gather_mm_rhs(a, b, rhs_indices, out, d, s); return; } diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 55b69b9ca..f28570618 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -451,8 +451,6 @@ void qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#ifdef MLX_ENABLE_NAX - void qmm_nax( const array& x, const array& w, @@ -653,8 +651,6 @@ void gather_qmm_nax( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#endif // MLX_ENABLE_NAX - void qmm( const array& x, const array& w, @@ -670,31 +666,25 @@ void qmm( metal::Device& d, const Stream& s, const std::string& mode) { -#ifdef MLX_ENABLE_NAX - - if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && transpose && (K % 64 == 0) && - (env::enable_tf32() || x.dtype() != float32)) { - return qmm_nax( - /* const array& x = */ x, - /* const array& w = */ w, - /* const array& scales = */ scales, - /* const std::optional& biases = */ biases, - /* array& out = */ out, - /* bool transpose = */ transpose, - /* int group_size = */ group_size, - /* int bits = */ bits, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* metal::Device& d = */ d, - /* const Stream& s = */ s, - /* const std::string& mode = */ mode); - } + if (metal::is_nax_available() && transpose && (K % 64 == 0) && + (env::enable_tf32() || x.dtype() != float32)) { + return qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); } -#endif // MLX_ENABLE_NAX - int B = out.size() / M / N; int wm = 2; @@ -772,33 +762,27 @@ void gather_qmm( metal::Device& d, const Stream& s, const std::string& mode) { -#ifdef MLX_ENABLE_NAX - - if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && transpose && (K % 64 == 0) && - (env::enable_tf32() || x.dtype() != float32)) { - return gather_qmm_nax( - /* const array& x = */ x, - /* const array& w = */ w, - /* const array& scales = */ scales, - /* const std::optional& biases = */ biases, - /* const array& lhs_indices = */ lhs_indices, - /* const array& rhs_indices = */ rhs_indices, - /* array& out = */ out, - /* bool transpose = */ transpose, - /* int group_size = */ group_size, - /* int bits = */ bits, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* metal::Device& d = */ d, - /* const Stream& s = */ s, - /* const std::string& mode = */ mode); - } + if (metal::is_nax_available() && transpose && (K % 64 == 0) && + (env::enable_tf32() || x.dtype() != float32)) { + return gather_qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* const array& lhs_indices = */ lhs_indices, + /* const array& rhs_indices = */ rhs_indices, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); } -#endif // MLX_ENABLE_NAX - int B = out.size() / M / N; int wm = 2; @@ -975,8 +959,6 @@ void gather_qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#ifdef MLX_ENABLE_NAX - void gather_qmm_rhs_nax( const array& x_, const array& w_, @@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#endif // MLX_ENABLE_NAX - void gather_qmm_rhs( const array& x_, const array& w_, @@ -1126,32 +1106,26 @@ void gather_qmm_rhs( metal::Device& d, const Stream& s, const std::string mode) { -#ifdef MLX_ENABLE_NAX - - if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && transpose && - (env::enable_tf32() || x_.dtype() != float32)) { - return gather_qmm_rhs_nax( - /* const array& x_ = */ x_, - /* const array& w_ = */ w_, - /* const array& scales_ = */ scales_, - /* const std::optional& biases_ = */ biases_, - /* const array& indices_ = */ indices_, - /* array& out = */ out, - /* bool transpose = */ transpose, - /* int group_size = */ group_size, - /* int bits = */ bits, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* metal::Device& d = */ d, - /* const Stream& s = */ s, - /* const std::string mode = */ mode); - } + if (metal::is_nax_available() && transpose && + (env::enable_tf32() || x_.dtype() != float32)) { + return gather_qmm_rhs_nax( + /* const array& x_ = */ x_, + /* const array& w_ = */ w_, + /* const array& scales_ = */ scales_, + /* const std::optional& biases_ = */ biases_, + /* const array& indices_ = */ indices_, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string mode = */ mode); } -#endif // MLX_ENABLE_NAX - // Start by normalizing the indices array indices = ensure_row_contiguous(indices_, d, s); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d3920b55d..731001a15 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -13,8 +13,6 @@ namespace mlx::core::fast { namespace { -#ifdef MLX_ENABLE_NAX - void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -#endif // MLX_ENABLE_NAX - void sdpa_full_self_attention_metal( const Stream& s, metal::Device& d, @@ -163,24 +159,20 @@ void sdpa_full_self_attention_metal( bool do_causal_, const std::optional& mask, const std::optional& sinks) { -#ifdef MLX_ENABLE_NAX - if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && q.shape(3) != 80 && - (env::enable_tf32() || q.dtype() != float32)) { - return sdpa_full_self_attention_nax( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& q = */ q, - /* const array& k = */ k, - /* const array& v = */ v, - /* const float scale = */ scale, - /* array& o = */ o, - /* bool do_causal_ = */ do_causal_, - /* const std::optional& mask = */ mask, - /* const std::optional& sinks = */ sinks); - } + if (metal::is_nax_available() && q.shape(3) != 80 && + (env::enable_tf32() || q.dtype() != float32)) { + return sdpa_full_self_attention_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& q = */ q, + /* const array& k = */ k, + /* const array& v = */ v, + /* const float scale = */ scale, + /* array& o = */ o, + /* bool do_causal_ = */ do_causal_, + /* const std::optional& mask = */ mask, + /* const std::optional& sinks = */ sinks); } -#endif // MLX_ENABLE_NAX using namespace mlx::steel;