diff --git a/examples/extensions/axpby/axpby.metal b/examples/extensions/axpby/axpby.metal index 503ad7444..7c5f32689 100644 --- a/examples/extensions/axpby/axpby.metal +++ b/examples/extensions/axpby/axpby.metal @@ -2,7 +2,6 @@ #include -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" template @@ -60,4 +59,4 @@ template instantiate_axpby(float32, float); instantiate_axpby(float16, half); instantiate_axpby(bfloat16, bfloat16_t); -instantiate_axpby(complex64, complex64_t); \ No newline at end of file +instantiate_axpby(complex64, complex64_t); diff --git a/mlx.pc.in b/mlx.pc.in index c3828b30b..c4e2515d7 100644 --- a/mlx.pc.in +++ b/mlx.pc.in @@ -28,10 +28,19 @@ endif() if (@MLX_BUILD_METAL@) set(MLX_BUILD_METAL @MLX_BUILD_METAL@) set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) - set_and_check(MLX_INCLUDE_DIRS - ${MLX_INCLUDE_DIRS} + set(MLX_INCLUDE_DIRS + "${MLX_INCLUDE_DIRS};" @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp ) + if(@MLX_METAL_VERSION@ GREATER_EQUAL 310) + set(MLX_INCLUDE_DIRS + "${MLX_INCLUDE_DIRS};" + @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1) + else() + set(MLX_INCLUDE_DIRS + "${MLX_INCLUDE_DIRS};" + @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0) + endif() endif() set_target_properties(mlx PROPERTIES @@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES ) include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) \ No newline at end of file +find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 4b14ebb55..2b3378862 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -21,7 +21,14 @@ function(make_jit_source SRC_FILE) target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) endfunction(make_jit_source) -make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h) +make_jit_source( + utils + kernels/jit/bf16.h + kernels/metal_3_0/bf16.h + kernels/metal_3_1/bf16.h + kernels/bf16_math.h + kernels/complex.h + kernels/defines.h) make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) make_jit_source(binary_ops) make_jit_source(ternary_ops) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index be3e0bc83..8f8a4468d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12; constexpr const char* default_mtllib_path = METAL_PATH; -constexpr auto get_metal_version() { -#if (MLX_METAL_VERSION >= 320) - return MTL::LanguageVersion3_2; -#elif (MLX_METAL_VERSION >= 310) - return MTL::LanguageVersion3_1; -#else - return MTL::LanguageVersion3_0; -#endif +auto get_metal_version() { + auto get_metal_version_ = []() { + if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { + return MTL::LanguageVersion3_2; + } else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) { + return MTL::LanguageVersion3_1; + } else { + return MTL::LanguageVersion3_0; + } + }; + static auto metal_version_ = get_metal_version_(); + return metal_version_; } auto load_device() { diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 9a8b9b7b4..978475c53 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -1,13 +1,27 @@ -set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h) +set(BASE_HEADERS + metal_3_1/bf16.h + metal_3_0/bf16.h + bf16_math.h + complex.h + defines.h + expm1f.h + utils.h) function(build_kernel_base TARGET SRCFILE DEPS) set(METAL_FLAGS -Wall -Wextra -fno-fast-math) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() + if(MLX_METAL_VERSION GREATER_EQUAL 310) + set(VERSION_INCLUDES + ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1) + else() + set(VERSION_INCLUDES + ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0) + endif() add_custom_command( COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} - -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air + -I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} OUTPUT ${TARGET}.air COMMENT "Building ${TARGET}.air" diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal index ebc81c630..c2e325697 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/arange.h" #define instantiate_arange(tname, type) \ diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h index 8583a5140..0643fb3ea 100644 --- a/mlx/backend/metal/kernels/bf16_math.h +++ b/mlx/backend/metal/kernels/bf16_math.h @@ -2,8 +2,6 @@ #pragma once -#include "mlx/backend/metal/kernels/bf16.h" - /////////////////////////////////////////////////////////////////////////////// // Metal math for bfloat16 /////////////////////////////////////////////////////////////////////////////// @@ -369,18 +367,6 @@ instantiate_metal_math_funcs( return static_cast(__metal_simd_xor(static_cast(data))); \ } -#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) - -#define bfloat16_to_uint16(x) as_type(x) -#define uint16_to_bfloat16(x) as_type(x) - -#else - -#define bfloat16_to_uint16(x) x.bits_ -#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) - -#endif - namespace metal { instantiate_metal_simd_comm_funcs( diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 4798460df..13ee239dc 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -4,8 +4,8 @@ #include #include -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/utils.h" #define MLX_MTL_CONST static constant constexpr const diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 7036d4b81..ffbf2be7c 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -2,7 +2,6 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/copy.h" #define instantiate_copy_all(tname, itype, otype) \ diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index de63dbff6..1776c54e2 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -3,8 +3,6 @@ #include #include -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h" @@ -912,4 +910,4 @@ template < // clang-format off instantiate_gemv_t_bs_blocks(float32, float); instantiate_gemv_t_bs_blocks(float16, half); -instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on diff --git a/mlx/backend/metal/kernels/gemv_masked.metal b/mlx/backend/metal/kernels/gemv_masked.metal index 7df97bac3..db787e7fc 100644 --- a/mlx/backend/metal/kernels/gemv_masked.metal +++ b/mlx/backend/metal/kernels/gemv_masked.metal @@ -4,8 +4,6 @@ #include #include -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/gemv_masked.h" diff --git a/mlx/backend/metal/kernels/jit/bf16.h b/mlx/backend/metal/kernels/jit/bf16.h new file mode 100644 index 000000000..702e8a4eb --- /dev/null +++ b/mlx/backend/metal/kernels/jit/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2024 Apple Inc. + +// clang-format off +#define jit_if #if +#define jit_else #else +#define jit_endif #endif + +jit_if (__METAL_VERSION__ >= 310) + +#include "mlx/backend/metal/kernels/metal_3_1/bf16.h" + +jit_else + +#include "mlx/backend/metal/kernels/metal_3_0/bf16.h" + +jit_endif // clang-format on diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 53dc89eb0..462eb3b94 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -3,8 +3,6 @@ #include #include -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" using namespace metal; diff --git a/mlx/backend/metal/kernels/bf16.h b/mlx/backend/metal/kernels/metal_3_0/bf16.h similarity index 98% rename from mlx/backend/metal/kernels/bf16.h rename to mlx/backend/metal/kernels/metal_3_0/bf16.h index a30108261..f5d486706 100644 --- a/mlx/backend/metal/kernels/bf16.h +++ b/mlx/backend/metal/kernels/metal_3_0/bf16.h @@ -6,12 +6,6 @@ using namespace metal; -#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) - -typedef bfloat bfloat16_t; - -#else - ///////////////////////////////////////////////////////////////////////////// // Helpers ///////////////////////////////////////////////////////////////////////////// @@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) { } // namespace metal #pragma METAL internals : disable +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return x.bits_; +} -#endif - -#include "mlx/backend/metal/kernels/bf16_math.h" +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()); +} diff --git a/mlx/backend/metal/kernels/metal_3_1/bf16.h b/mlx/backend/metal/kernels/metal_3_1/bf16.h new file mode 100644 index 000000000..aa3c3c780 --- /dev/null +++ b/mlx/backend/metal/kernels/metal_3_1/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +typedef bfloat bfloat16_t; +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return as_type(x); +} + +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return as_type(x); +} diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index 7d89dd052..f8fb53dd5 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -3,8 +3,6 @@ #include #include -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" using namespace metal; diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index a38cfcdff..b8f7a7c03 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -2,7 +2,6 @@ #include -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" template void rope_single_impl( diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index 34eea5e17..1b64d59a1 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -6,8 +6,6 @@ using namespace metal; // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/softmax.h" diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index e0d9b6c69..68248484d 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -3,8 +3,6 @@ #include // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/sort.h" diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal index 1bc99ffb0..960321440 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal @@ -5,7 +5,7 @@ // clang-format off #include "mlx/backend/metal/kernels/steel/gemm/mma.h" -#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h" diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal index 099822c04..8f8b8e0e0 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal @@ -5,7 +5,7 @@ // clang-format off #include "mlx/backend/metal/kernels/steel/gemm/mma.h" -#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/utils.h" diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal index 9f33a2bc6..4333be26c 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index 52bb8bb41..c127893ff 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal index 9def75cda..739e3f30e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index dacafadef..f12e0048f 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -4,7 +4,6 @@ #include // clang-format off -#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 33bf8fdae..151c6a64d 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -3,7 +3,13 @@ #pragma once #include -#include "mlx/backend/metal/kernels/bf16.h" + +// The correct bf16.h is included based on the metal version +// by giving the correct path to -I during compilation +// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0 +#include "bf16.h" + +#include "mlx/backend/metal/kernels/bf16_math.h" #include "mlx/backend/metal/kernels/complex.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh index 8e5d34b96..425cd8d70 100644 --- a/mlx/backend/metal/make_compiled_preamble.sh +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -11,12 +11,12 @@ SRC_DIR=$3 SRC_FILE=$4 CFLAGS=$5 SRC_NAME=$(basename -- "${SRC_FILE}") +JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp mkdir -p "$OUTPUT_DIR" - -CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null) +CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null) cat << EOF > "$OUTPUT_FILE" namespace mlx::core::metal {