mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT * remove softmax * fix versions
This commit is contained in:
parent
c417e42116
commit
1873ffda01
@ -87,9 +87,11 @@ elseif (MLX_BUILD_METAL)
|
|||||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||||
|
set(MLX_METAL_VERSION METAL_3_1)
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||||
|
set(MLX_METAL_VERSION METAL_3_0)
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||||
endif()
|
endif()
|
||||||
@ -111,6 +113,8 @@ elseif (MLX_BUILD_METAL)
|
|||||||
${METAL_LIB}
|
${METAL_LIB}
|
||||||
${FOUNDATION_LIB}
|
${FOUNDATION_LIB}
|
||||||
${QUARTZ_LIB})
|
${QUARTZ_LIB})
|
||||||
|
|
||||||
|
add_compile_definitions(${MLX_METAL_VERSION})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_CPU)
|
if (MLX_BUILD_CPU)
|
||||||
|
@ -5,10 +5,16 @@ add_custom_command(
|
|||||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||||
${CMAKE_C_COMPILER}
|
${CMAKE_C_COMPILER}
|
||||||
${PROJECT_SOURCE_DIR}
|
${PROJECT_SOURCE_DIR}
|
||||||
|
"-D${MLX_METAL_VERSION}"
|
||||||
DEPENDS make_compiled_preamble.sh
|
DEPENDS make_compiled_preamble.sh
|
||||||
kernels/compiled_preamble.h
|
kernels/compiled_preamble.h
|
||||||
kernels/unary.h
|
kernels/unary.h
|
||||||
kernels/binary.h
|
kernels/binary.h
|
||||||
|
kernels/bf16.h
|
||||||
|
kernels/erf.h
|
||||||
|
kernels/expm1f.h
|
||||||
|
kernels/utils.h
|
||||||
|
kernels/bf16_math.h
|
||||||
)
|
)
|
||||||
|
|
||||||
add_custom_target(
|
add_custom_target(
|
||||||
|
@ -29,6 +29,14 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
|||||||
|
|
||||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
|
constexpr auto get_metal_version() {
|
||||||
|
#if defined METAL_3_1
|
||||||
|
return MTL::LanguageVersion3_1;
|
||||||
|
#else
|
||||||
|
return MTL::LanguageVersion3_0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
auto load_device() {
|
auto load_device() {
|
||||||
auto devices = MTL::CopyAllDevices();
|
auto devices = MTL::CopyAllDevices();
|
||||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||||
@ -275,7 +283,12 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
|||||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||||
|
|
||||||
NS::Error* error = nullptr;
|
NS::Error* error = nullptr;
|
||||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
auto options = MTL::CompileOptions::alloc()->init();
|
||||||
|
options->setFastMathEnabled(false);
|
||||||
|
|
||||||
|
options->setLanguageVersion(get_metal_version());
|
||||||
|
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
|
||||||
|
options->release();
|
||||||
|
|
||||||
// Throw error if unable to compile library
|
// Throw error if unable to compile library
|
||||||
if (!mtl_lib) {
|
if (!mtl_lib) {
|
||||||
|
@ -39,7 +39,7 @@ set(
|
|||||||
)
|
)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS}
|
set(METAL_FLAGS ${METAL_FLAGS}
|
||||||
-gline-tables-only
|
-gline-tables-only
|
||||||
|
@ -6,7 +6,9 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
// No support for less than metal 3.0
|
||||||
|
// anything greater has native bfloat
|
||||||
|
#ifndef METAL_3_0
|
||||||
|
|
||||||
typedef bfloat bfloat16_t;
|
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
@ -312,6 +314,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
|||||||
|
|
||||||
#pragma METAL internals : disable
|
#pragma METAL internals : disable
|
||||||
|
|
||||||
#endif // defined(__HAVE_BFLOAT__)
|
#endif
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||||
|
@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
|||||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#ifndef METAL_3_0
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||||
@ -391,4 +391,4 @@ instantiate_metal_simd_comm_funcs(
|
|||||||
uint16_to_bfloat16);
|
uint16_to_bfloat16);
|
||||||
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
|
@ -9,8 +9,9 @@
|
|||||||
OUTPUT_FILE=$1
|
OUTPUT_FILE=$1
|
||||||
CC=$2
|
CC=$2
|
||||||
SRCDIR=$3
|
SRCDIR=$3
|
||||||
|
CFLAGS=$4
|
||||||
|
|
||||||
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
|
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
|
||||||
|
|
||||||
cat << EOF > "$OUTPUT_FILE"
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
// Copyright © 2023-24 Apple Inc.
|
// Copyright © 2023-24 Apple Inc.
|
||||||
|
Loading…
Reference in New Issue
Block a user