Detect metal version and propagate correctly for JIT (#1109)

* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
This commit is contained in:
Awni Hannun 2024-05-15 17:42:09 -07:00 committed by GitHub
parent c417e42116
commit 1873ffda01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 33 additions and 7 deletions

View File

@ -87,9 +87,11 @@ elseif (MLX_BUILD_METAL)
# Pick the metal-cpp archive, the patch applied to it, and the Metal language
# version macro (MLX_METAL_VERSION) that match the macOS SDK version.
#
# The comparison uses VERSION_GREATER_EQUAL rather than GREATER_EQUAL so the
# SDK version is compared component-wise as a dotted version instead of as a
# real number: "14.10" is newer than "14.2", and a three-component version
# such as "14.2.1" is handled instead of failing to parse as a number.
# MACOS_VERSION is quoted so an empty value falls through to the FATAL_ERROR
# branch instead of producing a malformed if() expression.
if ("${MACOS_VERSION}" VERSION_GREATER_EQUAL 14.2)
  set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
  set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
  set(MLX_METAL_VERSION METAL_3_1)
elseif ("${MACOS_VERSION}" VERSION_GREATER_EQUAL 14.0)
  set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
  set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
  set(MLX_METAL_VERSION METAL_3_0)
else()
  message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
@ -111,6 +113,8 @@ elseif (MLX_BUILD_METAL)
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
endif()
if (MLX_BUILD_CPU)

View File

@ -5,10 +5,16 @@ add_custom_command(
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
kernels/bf16.h
kernels/erf.h
kernels/expm1f.h
kernels/utils.h
kernels/bf16_math.h
)
add_custom_target(

View File

@ -29,6 +29,14 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
// Metal language version to request when compiling kernel source at runtime.
// Resolved at compile time from the METAL_3_1 preprocessor definition:
// LanguageVersion3_1 when it is defined, LanguageVersion3_0 otherwise.
constexpr auto get_metal_version() {
#ifndef METAL_3_1
  return MTL::LanguageVersion3_0;
#else
  return MTL::LanguageVersion3_1;
#endif
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@ -275,7 +283,12 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();
// Throw error if unable to compile library
if (!mtl_lib) {

View File

@ -39,7 +39,7 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only

View File

@ -6,7 +6,9 @@
using namespace metal;
#if defined(__HAVE_BFLOAT__)
// Metal versions below 3.0 are not supported.
// Any version newer than 3.0 has native bfloat support.
#ifndef METAL_3_0
typedef bfloat bfloat16_t;
@ -312,6 +314,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#endif
#include "mlx/backend/metal/kernels/bf16_math.h"

View File

@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#ifndef METAL_3_0
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@ -9,8 +9,9 @@
OUTPUT_FILE=$1
CC=$2
SRCDIR=$3
CFLAGS=$4
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
// Copyright © 2023-24 Apple Inc.