mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Get metal version from xcode (#1228)
* get metal version from xcode * typo * fix
This commit is contained in:
parent
4eef1e8a3e
commit
56c8a33439
@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL)
|
|||||||
OUTPUT_VARIABLE MACOS_VERSION
|
OUTPUT_VARIABLE MACOS_VERSION
|
||||||
COMMAND_ERROR_IS_FATAL ANY)
|
COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
|
if (${MACOS_VERSION} LESS 14.0)
|
||||||
|
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||||
|
endif()
|
||||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||||
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
|
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
|
||||||
if (${MACOS_VERSION} GREATER_EQUAL 15.0)
|
# Get the metal version
|
||||||
set(MLX_METAL_VERSION METAL_3_2)
|
execute_process(
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||||
set(MLX_METAL_VERSION METAL_3_1)
|
OUTPUT_VARIABLE MLX_METAL_VERSION
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
COMMAND_ERROR_IS_FATAL ANY)
|
||||||
set(MLX_METAL_VERSION METAL_3_0)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
|
||||||
endif()
|
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
metal_cpp
|
metal_cpp
|
||||||
@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL)
|
|||||||
${FOUNDATION_LIB}
|
${FOUNDATION_LIB}
|
||||||
${QUARTZ_LIB})
|
${QUARTZ_LIB})
|
||||||
|
|
||||||
add_compile_definitions(${MLX_METAL_VERSION})
|
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_CPU)
|
if (MLX_BUILD_CPU)
|
||||||
|
@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
|
|||||||
${CMAKE_C_COMPILER}
|
${CMAKE_C_COMPILER}
|
||||||
${PROJECT_SOURCE_DIR}
|
${PROJECT_SOURCE_DIR}
|
||||||
${SRC_FILE}
|
${SRC_FILE}
|
||||||
"-D${MLX_METAL_VERSION}"
|
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||||
DEPENDS make_compiled_preamble.sh
|
DEPENDS make_compiled_preamble.sh
|
||||||
kernels/${SRC_FILE}.h
|
kernels/${SRC_FILE}.h
|
||||||
${ARGN}
|
${ARGN}
|
||||||
|
@ -30,9 +30,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
|||||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
constexpr auto get_metal_version() {
|
constexpr auto get_metal_version() {
|
||||||
#if defined METAL_3_2
|
#if (MLX_METAL_VERSION >= 320)
|
||||||
return MTL::LanguageVersion3_2;
|
return MTL::LanguageVersion3_2;
|
||||||
#elif defined METAL_3_1
|
#elif (MLX_METAL_VERSION >= 310)
|
||||||
return MTL::LanguageVersion3_1;
|
return MTL::LanguageVersion3_1;
|
||||||
#else
|
#else
|
||||||
return MTL::LanguageVersion3_0;
|
return MTL::LanguageVersion3_0;
|
||||||
|
@ -8,7 +8,7 @@ set(
|
|||||||
)
|
)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS}
|
set(METAL_FLAGS ${METAL_FLAGS}
|
||||||
-gline-tables-only
|
-gline-tables-only
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
|
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||||
|
|
||||||
typedef bfloat bfloat16_t;
|
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
|
@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
|||||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
|
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user