mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-23 16:46:50 +08:00
Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT
* fix extension
* fix extension build
* fix extension build
* Update utils.h
This commit is contained in:
parent
b35f1e3c9c
commit
610af352d4
@ -2,7 +2,6 @@
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T>
|
||||
@ -60,4 +59,4 @@ template <typename T>
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
15
mlx.pc.in
15
mlx.pc.in
@ -28,10 +28,19 @@ endif()
|
||||
if (@MLX_BUILD_METAL@)
|
||||
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
||||
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
||||
set_and_check(MLX_INCLUDE_DIRS
|
||||
${MLX_INCLUDE_DIRS}
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
||||
)
|
||||
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_target_properties(mlx PROPERTIES
|
||||
@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
|
||||
)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||
|
@ -21,7 +21,14 @@ function(make_jit_source SRC_FILE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/jit/bf16.h
|
||||
kernels/metal_3_0/bf16.h
|
||||
kernels/metal_3_1/bf16.h
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
|
@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if (MLX_METAL_VERSION >= 320)
|
||||
return MTL::LanguageVersion3_2;
|
||||
#elif (MLX_METAL_VERSION >= 310)
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
auto get_metal_version() {
|
||||
auto get_metal_version_ = []() {
|
||||
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||
return MTL::LanguageVersion3_2;
|
||||
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
|
||||
return MTL::LanguageVersion3_1;
|
||||
} else {
|
||||
return MTL::LanguageVersion3_0;
|
||||
}
|
||||
};
|
||||
static auto metal_version_ = get_metal_version_();
|
||||
return metal_version_;
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
|
@ -1,13 +1,27 @@
|
||||
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
||||
set(BASE_HEADERS
|
||||
metal_3_1/bf16.h
|
||||
metal_3_0/bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
expm1f.h
|
||||
utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
|
@ -2,8 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
||||
#else
|
||||
|
||||
#define bfloat16_to_uint16(x) x.bits_
|
||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
||||
|
||||
#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
|
@ -4,8 +4,8 @@
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
|
@ -3,8 +3,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
@ -912,4 +910,4 @@ template <
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
instantiate_gemv_t_bs_blocks(float16, half);
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
|
@ -4,8 +4,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||
|
16
mlx/backend/metal/kernels/jit/bf16.h
Normal file
16
mlx/backend/metal/kernels/jit/bf16.h
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#define jit_if #if
|
||||
#define jit_else #else
|
||||
#define jit_endif #endif
|
||||
|
||||
jit_if (__METAL_VERSION__ >= 310)
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
|
||||
|
||||
jit_else
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
|
||||
|
||||
jit_endif // clang-format on
|
@ -3,8 +3,6 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
@ -6,12 +6,6 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return x.bits_;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
16
mlx/backend/metal/kernels/metal_3_1/bf16.h
Normal file
16
mlx/backend/metal/kernels/metal_3_1/bf16.h
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return as_type<uint16_t>(x);
|
||||
}
|
||||
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return as_type<bfloat16_t>(x);
|
||||
}
|
@ -3,8 +3,6 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
template <typename T, bool traditional, bool forward>
|
||||
void rope_single_impl(
|
||||
|
@ -6,8 +6,6 @@
|
||||
using namespace metal;
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/softmax.h"
|
||||
|
||||
|
@ -3,8 +3,6 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/sort.h"
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
|
||||
|
@ -5,7 +5,7 @@
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
|
||||
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <metal_math>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
@ -3,7 +3,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
// The correct bf16.h is included based on the metal version
|
||||
// by giving the correct path to -I during compilation
|
||||
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
|
||||
#include "bf16.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
#include "mlx/backend/metal/kernels/complex.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
||||
|
@ -11,12 +11,12 @@ SRC_DIR=$3
|
||||
SRC_FILE=$4
|
||||
CFLAGS=$5
|
||||
SRC_NAME=$(basename -- "${SRC_FILE}")
|
||||
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
|
||||
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
namespace mlx::core::metal {
|
||||
|
Loading…
Reference in New Issue
Block a user