mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT * fix extension * fix extension build * fix extension build * Update utils.h
This commit is contained in:
parent
b35f1e3c9c
commit
610af352d4
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
13
mlx.pc.in
13
mlx.pc.in
@ -28,10 +28,19 @@ endif()
|
|||||||
if (@MLX_BUILD_METAL@)
|
if (@MLX_BUILD_METAL@)
|
||||||
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
||||||
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
||||||
set_and_check(MLX_INCLUDE_DIRS
|
set(MLX_INCLUDE_DIRS
|
||||||
${MLX_INCLUDE_DIRS}
|
"${MLX_INCLUDE_DIRS};"
|
||||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
||||||
)
|
)
|
||||||
|
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
|
||||||
|
set(MLX_INCLUDE_DIRS
|
||||||
|
"${MLX_INCLUDE_DIRS};"
|
||||||
|
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
|
||||||
|
else()
|
||||||
|
set(MLX_INCLUDE_DIRS
|
||||||
|
"${MLX_INCLUDE_DIRS};"
|
||||||
|
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set_target_properties(mlx PROPERTIES
|
set_target_properties(mlx PROPERTIES
|
||||||
|
@ -21,7 +21,14 @@ function(make_jit_source SRC_FILE)
|
|||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||||
endfunction(make_jit_source)
|
endfunction(make_jit_source)
|
||||||
|
|
||||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
make_jit_source(
|
||||||
|
utils
|
||||||
|
kernels/jit/bf16.h
|
||||||
|
kernels/metal_3_0/bf16.h
|
||||||
|
kernels/metal_3_1/bf16.h
|
||||||
|
kernels/bf16_math.h
|
||||||
|
kernels/complex.h
|
||||||
|
kernels/defines.h)
|
||||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||||
make_jit_source(binary_ops)
|
make_jit_source(binary_ops)
|
||||||
make_jit_source(ternary_ops)
|
make_jit_source(ternary_ops)
|
||||||
|
@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
|||||||
|
|
||||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
constexpr auto get_metal_version() {
|
auto get_metal_version() {
|
||||||
#if (MLX_METAL_VERSION >= 320)
|
auto get_metal_version_ = []() {
|
||||||
return MTL::LanguageVersion3_2;
|
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||||
#elif (MLX_METAL_VERSION >= 310)
|
return MTL::LanguageVersion3_2;
|
||||||
return MTL::LanguageVersion3_1;
|
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
|
||||||
#else
|
return MTL::LanguageVersion3_1;
|
||||||
return MTL::LanguageVersion3_0;
|
} else {
|
||||||
#endif
|
return MTL::LanguageVersion3_0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static auto metal_version_ = get_metal_version_();
|
||||||
|
return metal_version_;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto load_device() {
|
auto load_device() {
|
||||||
|
@ -1,13 +1,27 @@
|
|||||||
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
set(BASE_HEADERS
|
||||||
|
metal_3_1/bf16.h
|
||||||
|
metal_3_0/bf16.h
|
||||||
|
bf16_math.h
|
||||||
|
complex.h
|
||||||
|
defines.h
|
||||||
|
expm1f.h
|
||||||
|
utils.h)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||||
endif()
|
endif()
|
||||||
|
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||||
|
set(VERSION_INCLUDES
|
||||||
|
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||||
|
else()
|
||||||
|
set(VERSION_INCLUDES
|
||||||
|
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
||||||
|
endif()
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
||||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||||
OUTPUT ${TARGET}.air
|
OUTPUT ${TARGET}.air
|
||||||
COMMENT "Building ${TARGET}.air"
|
COMMENT "Building ${TARGET}.air"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/arange.h"
|
#include "mlx/backend/metal/kernels/arange.h"
|
||||||
|
|
||||||
#define instantiate_arange(tname, type) \
|
#define instantiate_arange(tname, type) \
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Metal math for bfloat16
|
// Metal math for bfloat16
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
|
|||||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
|
||||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) x.bits_
|
|
||||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
instantiate_metal_simd_comm_funcs(
|
instantiate_metal_simd_comm_funcs(
|
||||||
|
@ -4,8 +4,8 @@
|
|||||||
#include <metal_simdgroup_matrix>
|
#include <metal_simdgroup_matrix>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#define MLX_MTL_CONST static constant constexpr const
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/copy.h"
|
#include "mlx/backend/metal/kernels/copy.h"
|
||||||
|
|
||||||
#define instantiate_copy_all(tname, itype, otype) \
|
#define instantiate_copy_all(tname, itype, otype) \
|
||||||
|
@ -3,8 +3,6 @@
|
|||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||||
|
@ -4,8 +4,6 @@
|
|||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||||
|
16
mlx/backend/metal/kernels/jit/bf16.h
Normal file
16
mlx/backend/metal/kernels/jit/bf16.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define jit_if #if
|
||||||
|
#define jit_else #else
|
||||||
|
#define jit_endif #endif
|
||||||
|
|
||||||
|
jit_if (__METAL_VERSION__ >= 310)
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
|
||||||
|
|
||||||
|
jit_else
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
|
||||||
|
|
||||||
|
jit_endif // clang-format on
|
@ -3,8 +3,6 @@
|
|||||||
#include <metal_common>
|
#include <metal_common>
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
@ -6,12 +6,6 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
|
||||||
|
|
||||||
typedef bfloat bfloat16_t;
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Helpers
|
// Helpers
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
|||||||
} // namespace metal
|
} // namespace metal
|
||||||
|
|
||||||
#pragma METAL internals : disable
|
#pragma METAL internals : disable
|
||||||
|
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||||
|
return x.bits_;
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||||
|
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
||||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
}
|
16
mlx/backend/metal/kernels/metal_3_1/bf16.h
Normal file
16
mlx/backend/metal/kernels/metal_3_1/bf16.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
typedef bfloat bfloat16_t;
|
||||||
|
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||||
|
return as_type<uint16_t>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||||
|
return as_type<bfloat16_t>(x);
|
||||||
|
}
|
@ -3,8 +3,6 @@
|
|||||||
#include <metal_common>
|
#include <metal_common>
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
template <typename T, bool traditional, bool forward>
|
template <typename T, bool traditional, bool forward>
|
||||||
void rope_single_impl(
|
void rope_single_impl(
|
||||||
|
@ -6,8 +6,6 @@
|
|||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/softmax.h"
|
#include "mlx/backend/metal/kernels/softmax.h"
|
||||||
|
|
||||||
|
@ -3,8 +3,6 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/sort.h"
|
#include "mlx/backend/metal/kernels/sort.h"
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||||
#include "mlx/backend/metal/kernels/ternary.h"
|
#include "mlx/backend/metal/kernels/ternary.h"
|
||||||
|
@ -3,7 +3,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
|
// The correct bf16.h is included based on the metal version
|
||||||
|
// by giving the correct path to -I during compilation
|
||||||
|
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
|
||||||
|
#include "bf16.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||||
#include "mlx/backend/metal/kernels/complex.h"
|
#include "mlx/backend/metal/kernels/complex.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
|
||||||
|
@ -11,12 +11,12 @@ SRC_DIR=$3
|
|||||||
SRC_FILE=$4
|
SRC_FILE=$4
|
||||||
CFLAGS=$5
|
CFLAGS=$5
|
||||||
SRC_NAME=$(basename -- "${SRC_FILE}")
|
SRC_NAME=$(basename -- "${SRC_FILE}")
|
||||||
|
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
|
||||||
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||||
|
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||||
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
|
||||||
|
|
||||||
cat << EOF > "$OUTPUT_FILE"
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
Loading…
Reference in New Issue
Block a user