From 56c8a33439e8c0aa158a76e4c719cd89f6d70ba5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 26 Jun 2024 07:02:11 -0700 Subject: [PATCH] Get metal version from xcode (#1228) * get metal version from xcode * typo * fix --- CMakeLists.txt | 19 +++++++++---------- mlx/backend/metal/CMakeLists.txt | 2 +- mlx/backend/metal/device.cpp | 4 ++-- mlx/backend/metal/kernels/CMakeLists.txt | 2 +- mlx/backend/metal/kernels/bf16.h | 2 +- mlx/backend/metal/kernels/bf16_math.h | 2 +- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d9b7f10b7..b3df1b4ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL) OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY) + if (${MACOS_VERSION} LESS 14.0) + message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) + endif() message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip) - if (${MACOS_VERSION} GREATER_EQUAL 15.0) - set(MLX_METAL_VERSION METAL_3_2) - elseif (${MACOS_VERSION} GREATER_EQUAL 14.2) - set(MLX_METAL_VERSION METAL_3_1) - elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) - set(MLX_METAL_VERSION METAL_3_0) - else() - message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) - endif() + # Get the metal version + execute_process( + COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" + OUTPUT_VARIABLE MLX_METAL_VERSION + COMMAND_ERROR_IS_FATAL ANY) FetchContent_Declare( metal_cpp @@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL) ${FOUNDATION_LIB} ${QUARTZ_LIB}) - add_compile_definitions(${MLX_METAL_VERSION}) + add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}") endif() if (MLX_BUILD_CPU) diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 8839237fe..b23c5af36 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE) ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} ${SRC_FILE} - "-D${MLX_METAL_VERSION}" + "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}" DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN} diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 46e7d4c08..9ef15497b 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -30,9 +30,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2; constexpr const char* default_mtllib_path = METAL_PATH; constexpr auto get_metal_version() { -#if defined METAL_3_2 +#if (MLX_METAL_VERSION >= 320) return MTL::LanguageVersion3_2; -#elif defined METAL_3_1 +#elif (MLX_METAL_VERSION >= 310) return MTL::LanguageVersion3_1; #else return MTL::LanguageVersion3_0; diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4d36c6538..bb9f771b9 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -8,7 +8,7 @@ set( ) function(build_kernel_base TARGET SRCFILE DEPS) - set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION}) + set(METAL_FLAGS -Wall -Wextra -fno-fast-math) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only diff --git a/mlx/backend/metal/kernels/bf16.h b/mlx/backend/metal/kernels/bf16.h index c2dd5cc47..a30108261 100644 --- a/mlx/backend/metal/kernels/bf16.h +++ b/mlx/backend/metal/kernels/bf16.h @@ -6,7 +6,7 @@ using namespace metal; -#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) +#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) typedef bfloat bfloat16_t; diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h index 7c0c04f19..8583a5140 100644 --- a/mlx/backend/metal/kernels/bf16_math.h +++ b/mlx/backend/metal/kernels/bf16_math.h @@ -369,7 +369,7 @@ instantiate_metal_math_funcs( return static_cast(__metal_simd_xor(static_cast(data))); \ } -#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) +#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) #define bfloat16_to_uint16(x) as_type(x) #define uint16_to_bfloat16(x) as_type(x)