mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT * remove softmax * fix versions
This commit is contained in:
parent
c417e42116
commit
1873ffda01
@ -87,9 +87,11 @@ elseif (MLX_BUILD_METAL)
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_1)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_0)
|
||||
else()
|
||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
@ -111,6 +113,8 @@ elseif (MLX_BUILD_METAL)
|
||||
${METAL_LIB}
|
||||
${FOUNDATION_LIB}
|
||||
${QUARTZ_LIB})
|
||||
|
||||
add_compile_definitions(${MLX_METAL_VERSION})
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_CPU)
|
||||
|
@ -5,10 +5,16 @@ add_custom_command(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${CMAKE_C_COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
"-D${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/compiled_preamble.h
|
||||
kernels/unary.h
|
||||
kernels/binary.h
|
||||
kernels/bf16.h
|
||||
kernels/erf.h
|
||||
kernels/expm1f.h
|
||||
kernels/utils.h
|
||||
kernels/bf16_math.h
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
|
@ -29,6 +29,14 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if defined METAL_3_1
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
@ -275,7 +283,12 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
||||
auto options = MTL::CompileOptions::alloc()->init();
|
||||
options->setFastMathEnabled(false);
|
||||
|
||||
options->setLanguageVersion(get_metal_version());
|
||||
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
|
||||
options->release();
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
|
@ -39,7 +39,7 @@ set(
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
|
@ -6,7 +6,9 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
// No support for less than metal 3.0
|
||||
// anything greater has native bfloat
|
||||
#ifndef METAL_3_0
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
@ -312,6 +314,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif // defined(__HAVE_BFLOAT__)
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
|
@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
#ifndef METAL_3_0
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
@ -9,8 +9,9 @@
|
||||
OUTPUT_FILE=$1
|
||||
CC=$2
|
||||
SRCDIR=$3
|
||||
CFLAGS=$4
|
||||
|
||||
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
|
||||
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
Loading…
Reference in New Issue
Block a user