From 1873ffda018e4974ad0db3cbe014ed14e3b9a0f1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 15 May 2024 17:42:09 -0700
Subject: [PATCH] Detect metal version and propagate correctly for JIT (#1109)

* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
---
NOTE(review): this patch was re-flowed from a copy whose line breaks and
indentation had been collapsed. Hunk line counts were used to place blank
context lines; leading whitespace inside hunks is reconstructed from nesting,
not copied from the tree. Expect to need `git apply --ignore-whitespace`, or
regenerate the patch from the commit, before relying on it.
NOTE(review): a few context lines look truncated, presumably by the same
extraction: `static_cast(devices->object(0))`, `static_cast(__metal_simd_xor(
static_cast(data)))` and `as_type(x)` seem to be missing angle-bracketed
template arguments, and the From: line has no address. They are left exactly
as found; confirm them against the target files before applying.

 CMakeLists.txt                              |  4 ++++
 mlx/backend/metal/CMakeLists.txt            |  6 ++++++
 mlx/backend/metal/device.cpp                | 15 ++++++++++++++-
 mlx/backend/metal/kernels/CMakeLists.txt    |  2 +-
 mlx/backend/metal/kernels/bf16.h            |  6 ++++--
 mlx/backend/metal/kernels/bf16_math.h       |  4 ++--
 mlx/backend/metal/make_compiled_preamble.sh |  3 ++-
 7 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6bf6d6697..facf0034f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,9 +87,11 @@ elseif (MLX_BUILD_METAL)
   if (${MACOS_VERSION} GREATER_EQUAL 14.2)
     set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
+    set(MLX_METAL_VERSION METAL_3_1)
   elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
     set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
+    set(MLX_METAL_VERSION METAL_3_0)
   else()
     message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
   endif()
@@ -111,6 +113,8 @@ elseif (MLX_BUILD_METAL)
     ${METAL_LIB}
     ${FOUNDATION_LIB}
     ${QUARTZ_LIB})
+
+  add_compile_definitions(${MLX_METAL_VERSION})
 endif()
 
 if (MLX_BUILD_CPU)
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 0a77f0bda..ccc7fb9c3 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -5,10 +5,16 @@ add_custom_command(
             ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
             ${CMAKE_C_COMPILER}
             ${PROJECT_SOURCE_DIR}
+            "-D${MLX_METAL_VERSION}"
   DEPENDS make_compiled_preamble.sh
           kernels/compiled_preamble.h
           kernels/unary.h
           kernels/binary.h
+          kernels/bf16.h
+          kernels/erf.h
+          kernels/expm1f.h
+          kernels/utils.h
+          kernels/bf16_math.h
 )
 
 add_custom_target(
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 03974db3e..479d5dc64 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -29,6 +29,14 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
 
 constexpr const char* default_mtllib_path = METAL_PATH;
 
+constexpr auto get_metal_version() {
+#if defined METAL_3_1
+  return MTL::LanguageVersion3_1;
+#else
+  return MTL::LanguageVersion3_0;
+#endif
+}
+
 auto load_device() {
   auto devices = MTL::CopyAllDevices();
   auto device = static_cast(devices->object(0))
@@ -275,7 +283,12 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
       NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
 
   NS::Error* error = nullptr;
-  auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
+  auto options = MTL::CompileOptions::alloc()->init();
+  options->setFastMathEnabled(false);
+
+  options->setLanguageVersion(get_metal_version());
+  auto mtl_lib = device_->newLibrary(ns_code, options, &error);
+  options->release();
 
   // Throw error if unable to compile library
   if (!mtl_lib) {
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 2010cb85a..ec406327e 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -39,7 +39,7 @@ set(
 )
 
 function(build_kernel_base TARGET SRCFILE DEPS)
-  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
+  set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
   if(MLX_METAL_DEBUG)
     set(METAL_FLAGS ${METAL_FLAGS}
       -gline-tables-only
diff --git a/mlx/backend/metal/kernels/bf16.h b/mlx/backend/metal/kernels/bf16.h
index 71a45e1c5..03c73f9c2 100644
--- a/mlx/backend/metal/kernels/bf16.h
+++ b/mlx/backend/metal/kernels/bf16.h
@@ -6,7 +6,9 @@
 
 using namespace metal;
 
-#if defined(__HAVE_BFLOAT__)
+// No support for less than metal 3.0
+// anything greater has native bfloat
+#ifndef METAL_3_0
 
 typedef bfloat bfloat16_t;
 
@@ -312,6 +314,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
 
 #pragma METAL internals : disable
 
-#endif // defined(__HAVE_BFLOAT__)
+#endif
 
 #include "mlx/backend/metal/kernels/bf16_math.h"
diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h
index f1fec4336..929429bdd 100644
--- a/mlx/backend/metal/kernels/bf16_math.h
+++ b/mlx/backend/metal/kernels/bf16_math.h
@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
     return static_cast(__metal_simd_xor(static_cast(data))); \
   }
 
-#if defined(__HAVE_BFLOAT__)
+#ifndef METAL_3_0
 
 #define bfloat16_to_uint16(x) as_type(x)
 #define uint16_to_bfloat16(x) as_type(x)
@@ -391,4 +391,4 @@ instantiate_metal_simd_comm_funcs(
     uint16_to_bfloat16);
 instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
 
-} // namespace metal
\ No newline at end of file
+} // namespace metal
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index 26b575de4..dedd38a64 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -9,8 +9,9 @@
 OUTPUT_FILE=$1
 CC=$2
 SRCDIR=$3
+CFLAGS=$4
 
-CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
+CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
 
 cat << EOF > "$OUTPUT_FILE"
 // Copyright © 2023-24 Apple Inc.