Dispatch bf16 at run time when using the JIT (#1584)

* Dispatch bf16 at run time when using the JIT

* fix extension

* fix extension build

* fix extension build

* Update utils.h
This commit is contained in:
Awni Hannun 2024-11-15 16:54:36 -08:00 committed by GitHub
parent b35f1e3c9c
commit 610af352d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 102 additions and 65 deletions

View File

@ -2,7 +2,6 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T> template <typename T>
@ -60,4 +59,4 @@ template <typename T>
instantiate_axpby(float32, float); instantiate_axpby(float32, float);
instantiate_axpby(float16, half); instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t); instantiate_axpby(complex64, complex64_t);

View File

@ -28,10 +28,19 @@ endif()
if (@MLX_BUILD_METAL@) if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@) set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set_and_check(MLX_INCLUDE_DIRS set(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS} "${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
) )
# Add the directory that holds the bf16.h variant for the Metal version this
# package was configured with (@MLX_METAL_VERSION@ is substituted when the
# config file is generated). The kernel headers include a bare "bf16.h", so the
# matching directory has to be on the include path.
# list(APPEND) is used instead of set(VAR "${VAR};" path): the latter yields
# "<old>;;path" and leaves an empty element in the list.
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
  list(APPEND MLX_INCLUDE_DIRS
       "@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1")
else()
  list(APPEND MLX_INCLUDE_DIRS
       "@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0")
endif()
endif() endif()
set_target_properties(mlx PROPERTIES set_target_properties(mlx PROPERTIES
@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
) )
include(FindPackageHandleStandardArgs) include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@ -21,7 +21,14 @@ function(make_jit_source SRC_FILE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h) make_jit_source(
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)

View File

@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() { auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320) auto get_metal_version_ = []() {
return MTL::LanguageVersion3_2; if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
#elif (MLX_METAL_VERSION >= 310) return MTL::LanguageVersion3_2;
return MTL::LanguageVersion3_1; } else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
#else return MTL::LanguageVersion3_1;
return MTL::LanguageVersion3_0; } else {
#endif return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
} }
auto load_device() { auto load_device() {

View File

@ -1,13 +1,27 @@
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h) set(BASE_HEADERS
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air -I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air" COMMENT "Building ${TARGET}.air"

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/arange.h" #include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \ #define instantiate_arange(tname, type) \

View File

@ -2,8 +2,6 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/bf16.h"
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16 // Metal math for bfloat16
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
} }
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal { namespace metal {
instantiate_metal_simd_comm_funcs( instantiate_metal_simd_comm_funcs(

View File

@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix> #include <metal_simdgroup_matrix>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const

View File

@ -2,7 +2,6 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h" #include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \ #define instantiate_copy_all(tname, itype, otype) \

View File

@ -3,8 +3,6 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h"
@ -912,4 +910,4 @@ template <
// clang-format off // clang-format off
instantiate_gemv_t_bs_blocks(float32, float); instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half); instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@ -4,8 +4,6 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h" #include "mlx/backend/metal/kernels/gemv_masked.h"

View File

@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
// JIT variant of bf16.h: pulls in both the Metal 3.1 and the Metal 3.0
// bfloat16 headers, each under one branch of a __METAL_VERSION__ conditional.
// clang-format off
// The jit_* macros expand to the text of a preprocessor directive rather than
// being directives themselves.
// NOTE(review): presumably this lets the conditional survive the `-E -P`
// pre-expansion done when the JIT sources are generated, so that
// __METAL_VERSION__ is only tested when the kernel source is compiled at run
// time — confirm against the preamble generation script.
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
// Metal 3.1 and newer: use the metal_3_1 header.
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
// Older Metal versions: fall back to the metal_3_0 header.
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on

View File

@ -3,8 +3,6 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;

View File

@ -6,12 +6,6 @@
using namespace metal; using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Helpers // Helpers
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal } // namespace metal
#pragma METAL internals : disable #pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
#endif inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
#include "mlx/backend/metal/kernels/bf16_math.h" }

View File

@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
// Expose `bfloat` under the `bfloat16_t` name used by the helpers below.
// NOTE(review): `bfloat` is not declared in this header; presumably it is the
// built-in type available from Metal 3.1 on (cf. the metal_3_1 directory) —
// confirm.
typedef bfloat bfloat16_t;
// Returns `x` converted to a uint16_t through as_type<uint16_t>.
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
  const uint16_t raw_bits = as_type<uint16_t>(x);
  return raw_bits;
}
// Returns `x` converted to a bfloat16_t through as_type<bfloat16_t>;
// counterpart of bfloat16_to_uint16.
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
  const bfloat16_t value = as_type<bfloat16_t>(x);
  return value;
}

View File

@ -3,8 +3,6 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;

View File

@ -2,7 +2,6 @@
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>
void rope_single_impl( void rope_single_impl(

View File

@ -6,8 +6,6 @@
using namespace metal; using namespace metal;
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h" #include "mlx/backend/metal/kernels/softmax.h"

View File

@ -3,8 +3,6 @@
#include <metal_stdlib> #include <metal_stdlib>
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sort.h" #include "mlx/backend/metal/kernels/sort.h"

View File

@ -5,7 +5,7 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h" #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@ -5,7 +5,7 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h"

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

View File

@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

View File

@ -4,7 +4,6 @@
#include <metal_math> #include <metal_math>
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/ternary.h"

View File

@ -3,7 +3,13 @@
#pragma once #pragma once
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
// The bf16.h that matches the target Metal version is selected through the
// include search path: the build passes the version-specific directory via -I
// (e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0).
#include "bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h" #include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"

View File

@ -11,12 +11,12 @@ SRC_DIR=$3
SRC_FILE=$4 SRC_FILE=$4
CFLAGS=$5 CFLAGS=$5
SRC_NAME=$(basename -- "${SRC_FILE}") SRC_NAME=$(basename -- "${SRC_FILE}")
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR" mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE" cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal { namespace mlx::core::metal {