Get metal version from xcode (#1228)

* get metal version from xcode

* typo

* fix
This commit is contained in:
Awni Hannun 2024-06-26 07:02:11 -07:00 committed by GitHub
parent 4eef1e8a3e
commit 56c8a33439
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 15 additions and 16 deletions

View File

@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL)
OUTPUT_VARIABLE MACOS_VERSION OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY) COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
if (${MACOS_VERSION} GREATER_EQUAL 15.0) # Get the metal version
set(MLX_METAL_VERSION METAL_3_2) execute_process(
elseif (${MACOS_VERSION} GREATER_EQUAL 14.2) COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
set(MLX_METAL_VERSION METAL_3_1) OUTPUT_VARIABLE MLX_METAL_VERSION
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) COMMAND_ERROR_IS_FATAL ANY)
set(MLX_METAL_VERSION METAL_3_0)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare( FetchContent_Declare(
metal_cpp metal_cpp
@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL)
${FOUNDATION_LIB} ${FOUNDATION_LIB}
${QUARTZ_LIB}) ${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION}) add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif() endif()
if (MLX_BUILD_CPU) if (MLX_BUILD_CPU)

View File

@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
${CMAKE_C_COMPILER} ${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}
${SRC_FILE} ${SRC_FILE}
"-D${MLX_METAL_VERSION}" "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h kernels/${SRC_FILE}.h
${ARGN} ${ARGN}

View File

@ -30,9 +30,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() { constexpr auto get_metal_version() {
#if defined METAL_3_2 #if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2; return MTL::LanguageVersion3_2;
#elif defined METAL_3_1 #elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1; return MTL::LanguageVersion3_1;
#else #else
return MTL::LanguageVersion3_0; return MTL::LanguageVersion3_0;

View File

@ -8,7 +8,7 @@ set(
) )
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION}) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only -gline-tables-only

View File

@ -6,7 +6,7 @@
using namespace metal; using namespace metal;
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) #if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t; typedef bfloat bfloat16_t;

View File

@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
} }
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) #if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x) #define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x) #define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)