diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a3a8b6d9..d14aa2afe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,13 +85,12 @@ elseif (MLX_BUILD_METAL) message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") - if (${MACOS_VERSION} GREATER_EQUAL 14.2) - set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff) - set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip) + set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip) + if (${MACOS_VERSION} GREATER_EQUAL 15.0) + set(MLX_METAL_VERSION METAL_3_2) + elseif (${MACOS_VERSION} GREATER_EQUAL 14.2) set(MLX_METAL_VERSION METAL_3_1) elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) - set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff) - set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip) set(MLX_METAL_VERSION METAL_3_0) else() message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) @@ -100,7 +99,6 @@ elseif (MLX_BUILD_METAL) FetchContent_Declare( metal_cpp URL ${METAL_CPP_URL} - PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true ) FetchContent_MakeAvailable(metal_cpp) diff --git a/cmake/metal.14.0.diff b/cmake/metal.14.0.diff deleted file mode 100644 index 3609fd916..000000000 --- a/cmake/metal.14.0.diff +++ /dev/null @@ -1,36 +0,0 @@ -diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp ---- Metal/MTLEvent.hpp 2023-06-01 12:18:26 -+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59 -@@ -62,6 +62,7 @@ - - uint64_t signaledValue() const; - void setSignaledValue(uint64_t signaledValue); -+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS); - }; - - class SharedEventHandle : public NS::SecureCoding -@@ -138,6 +139,11 @@ - _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) - { - Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); -+} -+ -+// method: waitUntilSignaledValue -+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) { -+ return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS); - } - - // static method: alloc -diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp ---- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26 -+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29 -@@ -1906,6 +1906,9 @@ - "setShouldMaximizeConcurrentCompilation:"); - _MTL_PRIVATE_DEF_SEL(setSignaledValue_, - "setSignaledValue:"); -+_MTL_PRIVATE_DEF_SEL( -+ waitUntilSignaledValue_timeoutMS_, -+ "waitUntilSignaledValue:timeoutMS:"); - _MTL_PRIVATE_DEF_SEL(setSize_, - "setSize:"); - _MTL_PRIVATE_DEF_SEL(setSlice_, diff --git a/cmake/metal.14.2.diff b/cmake/metal.14.2.diff deleted file mode 100644 index 8634afaa7..000000000 --- a/cmake/metal.14.2.diff +++ /dev/null @@ -1,36 +0,0 @@ -diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp ---- Metal/MTLEvent.hpp 2024-04-15 07:12:10 -+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50 -@@ -62,6 +62,7 @@ - - uint64_t signaledValue() const; - void setSignaledValue(uint64_t signaledValue); -+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS); - }; - - class SharedEventHandle : public NS::SecureCoding -@@ -138,6 +139,11 @@ - _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) - { - Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); -+} -+ -+// method: waitUntilSignaledValue -+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) { -+ return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS); - } - - // static method: alloc -diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp ---- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10 -+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15 -@@ -1918,6 +1918,9 @@ - "setShouldMaximizeConcurrentCompilation:"); - _MTL_PRIVATE_DEF_SEL(setSignaledValue_, - "setSignaledValue:"); -+_MTL_PRIVATE_DEF_SEL( -+ waitUntilSignaledValue_timeoutMS_, -+ "waitUntilSignaledValue:timeoutMS:"); - _MTL_PRIVATE_DEF_SEL(setSize_, - "setSize:"); - _MTL_PRIVATE_DEF_SEL(setSlice_, diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index a22d8dd0e..46e7d4c08 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -30,7 +30,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2; constexpr const char* default_mtllib_path = METAL_PATH; constexpr auto get_metal_version() { -#if defined METAL_3_1 +#if defined METAL_3_2 + return MTL::LanguageVersion3_2; +#elif defined METAL_3_1 return MTL::LanguageVersion3_1; #else return MTL::LanguageVersion3_0; diff --git a/mlx/backend/metal/kernels/bf16.h b/mlx/backend/metal/kernels/bf16.h index 726b676bb..c2dd5cc47 100644 --- a/mlx/backend/metal/kernels/bf16.h +++ b/mlx/backend/metal/kernels/bf16.h @@ -6,7 +6,7 @@ using namespace metal; -#if defined METAL_3_1 || (__METAL_VERSION__ >= 310) +#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) typedef bfloat bfloat16_t; diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h index 8c48b8cfd..7c0c04f19 100644 --- a/mlx/backend/metal/kernels/bf16_math.h +++ b/mlx/backend/metal/kernels/bf16_math.h @@ -369,7 +369,7 @@ instantiate_metal_math_funcs( return static_cast(__metal_simd_xor(static_cast(data))); \ } -#if defined METAL_3_1 || (__METAL_VERSION__ >= 310) +#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310) #define bfloat16_to_uint16(x) as_type(x) #define uint16_to_bfloat16(x) as_type(x)