Get metal version from xcode (#1228)

* get metal version from xcode

* typo

* fix
This commit is contained in:
Awni Hannun 2024-06-26 07:02:11 -07:00 committed by GitHub
parent 4eef1e8a3e
commit 56c8a33439
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 15 additions and 16 deletions

View File

@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL)
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
if (${MACOS_VERSION} GREATER_EQUAL 15.0)
set(MLX_METAL_VERSION METAL_3_2)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(MLX_METAL_VERSION METAL_3_0)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
# Get the metal version
execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION
COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(
metal_cpp
@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL)
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if (MLX_BUILD_CPU)

View File

@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}

View File

@ -30,9 +30,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if defined METAL_3_2
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif defined METAL_3_1
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;

View File

@ -8,7 +8,7 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only

View File

@ -6,7 +6,7 @@
using namespace metal;
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)