Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-11-21 13:28:15 -08:00
committed by GitHub
parent 0d68efd461
commit 0dbc7e5bee
6 changed files with 135 additions and 191 deletions

View File

@@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(MLX_ENABLE_NAX TRUE)
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
set(MLX_ENABLE_NAX FALSE)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(mlx

View File

@@ -265,14 +265,19 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
#ifdef MLX_ENABLE_NAX
// Reports whether the NAX code path may be used at runtime. The answer is
// computed once by the lambda below and cached in a function-local static, so
// later calls only return the cached value.
inline bool is_nax_available() {
// NOTE(review): the next line has no initializer before `auto` follows it; it
// looks like the pre-change declaration retained by the diff rendering —
// confirm against the actual header.
static bool is_nax_available_ =
auto _check_nax = []() {
bool can_use_nax = false;
// Runtime OS gate: requires version 26.2 or newer on the listed Apple
// platforms (the trailing `*` covers any platform not named explicitly).
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
can_use_nax = true;
}
// Additionally require the GPU device's architecture generation to be >= 17;
// `&=` keeps the result false when the OS gate above was not satisfied.
can_use_nax &=
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return can_use_nax;
};
// Evaluated only the first time control reaches this declaration.
static bool is_nax_available_ = _check_nax();
return is_nax_available_;
}
#endif // MLX_ENABLE_NAX
} // namespace mlx::core::metal

View File

@@ -9,13 +9,17 @@ set(BASE_HEADERS
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_ENABLE_NAX)
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()
if(MLX_ENABLE_NAX)
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h

View File

@@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////
#ifdef MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby_nax(
const Stream& s,
@@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax(
d.add_temporaries(std::move(copies), s.index);
}
#endif // MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby(
const Stream& s,
@@ -357,9 +353,6 @@ void steel_matmul_regular_axpby(
int64_t C_batch_stride /* = 0*/,
float alpha /* = 1.0f */,
float beta /* = 0.0f */) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return steel_matmul_regular_axpby_nax<CHECK_AB>(
@@ -388,9 +381,6 @@ void steel_matmul_regular_axpby(
/* float alpha = */ alpha,
/* float beta = */ beta);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;
@@ -1766,8 +1756,6 @@ void gather_mm_rhs(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_mm_rhs_nax(
const array& a_,
const array& b_,
@@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_mv(
const array& mat_,
const array& vec_,
@@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() &&
!issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
}
}
#endif // MLX_ENABLE_NAX
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}

View File

@@ -451,8 +451,6 @@ void qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void qmm_nax(
const array& x,
const array& w,
@@ -653,8 +651,6 @@ void gather_qmm_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void qmm(
const array& x,
const array& w,
@@ -670,9 +666,6 @@ void qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return qmm_nax(
@@ -691,9 +684,6 @@ void qmm(
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
@@ -772,9 +762,6 @@ void gather_qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return gather_qmm_nax(
@@ -795,9 +782,6 @@ void gather_qmm(
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
@@ -975,8 +959,6 @@ void gather_qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_qmm_rhs_nax(
const array& x_,
const array& w_,
@@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_qmm_rhs(
const array& x_,
const array& w_,
@@ -1126,9 +1106,6 @@ void gather_qmm_rhs(
metal::Device& d,
const Stream& s,
const std::string mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
return gather_qmm_rhs_nax(
@@ -1148,9 +1125,6 @@ void gather_qmm_rhs(
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);

View File

@@ -13,8 +13,6 @@ namespace mlx::core::fast {
namespace {
#ifdef MLX_ENABLE_NAX
void sdpa_full_self_attention_nax(
const Stream& s,
metal::Device& d,
@@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void sdpa_full_self_attention_metal(
const Stream& s,
metal::Device& d,
@@ -163,8 +159,6 @@ void sdpa_full_self_attention_metal(
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
@@ -179,8 +173,6 @@ void sdpa_full_self_attention_metal(
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;