mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 19:31:16 +08:00
127 lines
4.9 KiB
CMake
127 lines
4.9 KiB
CMake
function(make_jit_source SRC_FILE)
  # Takes a Metal header file, runs the C preprocessor on it, and makes the
  # processed contents available as a string in a C++ function
  # mlx::core::metal::${SRC_NAME}()
  #
  # To use the function, declare it in jit/includes.h and include
  # jit/includes.h.
  #
  # Arguments:
  #   SRC_FILE - header to process, given relative to kernels/ and without the
  #              ".h" suffix (the rule depends on kernels/${SRC_FILE}.h).
  #   ARGN     - additional arguments are treated as dependencies of the
  #              generated source in the CMake build system.
  #
  # Example: make_jit_source(scatter kernels/indexing.h)

  # SRC_NAME is the last path component of SRC_FILE. It names both the
  # declared output jit/${SRC_NAME}.cpp and the custom target that drives it.
  get_filename_component(SRC_NAME ${SRC_FILE} NAME)

  # One absolute spelling of the generated file, shared by the rule, the
  # custom target and the mlx source list.
  set(jit_output ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)

  # The script receives, in order: the output directory, the C compiler, the
  # project source directory and SRC_FILE. VERBATIM makes CMake escape these
  # arguments the same way on every generator/platform.
  add_custom_command(
    OUTPUT ${jit_output}
    COMMAND
      bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
      ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
      ${SRC_FILE}
    DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN}
    VERBATIM)

  # Make mlx depend on a target that produces the generated file, then compile
  # that file as part of mlx.
  add_custom_target(${SRC_NAME} DEPENDS ${jit_output})
  add_dependencies(mlx ${SRC_NAME})
  target_sources(mlx PRIVATE ${jit_output})
endfunction()
|
|
|
|
# JIT sources generated regardless of MLX_METAL_JIT. For each call the first
# argument names a header under kernels/ (without ".h"); any further arguments
# are extra dependencies of the generated source.
make_jit_source(
  utils
  kernels/jit/bf16.h
  kernels/metal_3_0/bf16.h
  kernels/metal_3_1/bf16.h
  kernels/bf16_math.h
  kernels/complex.h
  kernels/defines.h)

make_jit_source(
  unary_ops
  kernels/erf.h
  kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)

make_jit_source(
  reduce_utils
  kernels/atomic.h
  kernels/reduction/ops.h)

# scatter and gather both list kernels/indexing.h as a dependency.
make_jit_source(
  scatter
  kernels/indexing.h)
make_jit_source(
  gather
  kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)

make_jit_source(hadamard)
|
|
|
|
# With MLX_METAL_JIT enabled, compile jit_kernels.cpp and generate a JIT source
# for each kernel header listed below; otherwise compile nojit_kernels.cpp
# instead and generate none of them.
if(MLX_METAL_JIT)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)

  make_jit_source(arange)
  make_jit_source(copy)
  make_jit_source(unary)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(
    fft
    kernels/fft/radix.h
    kernels/fft/readwrite.h)
  make_jit_source(logsumexp)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
  make_jit_source(sort)
  make_jit_source(
    reduce
    kernels/reduction/reduce_all.h
    kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h
    kernels/reduction/reduce_init.h)

  # Headers and kernels under steel/gemm.
  make_jit_source(
    steel/gemm/gemm
    kernels/steel/utils.h
    kernels/steel/gemm/loader.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/params.h
    kernels/steel/gemm/transforms.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(
    steel/gemm/kernels/steel_gemm_masked
    kernels/steel/defines.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)

  # Headers and kernels under steel/conv.
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
    kernels/steel/defines.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/transforms.h
    kernels/steel/conv/params.h
    kernels/steel/conv/loader.h
    kernels/steel/conv/loaders/loader_channel_l.h
    kernels/steel/conv/loaders/loader_channel_n.h)
  make_jit_source(steel/conv/kernels/steel_conv)
  make_jit_source(
    steel/conv/kernels/steel_conv_general
    kernels/steel/defines.h
    kernels/steel/conv/loaders/loader_general.h)

  make_jit_source(quantized)
  make_jit_source(gemv_masked)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
endif()
|
|
|
|
# Sources from this directory that are always compiled into mlx, independent
# of the MLX_METAL_JIT setting.
target_sources(
  mlx
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
|
|
|
# Directory used to locate mlx.metallib. A value supplied from outside is kept;
# otherwise default to kernels/ under the current binary directory.
if(NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

# MLX_METAL_PATH is set before descending so the kernels/ listfile can read it
# (presumably that is where mlx.metallib gets built; confirm in
# kernels/CMakeLists.txt).
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

# Pass the metallib location to the mlx sources as the METAL_PATH macro.
# NOTE: the default above already ends in "/", so the default expansion
# contains "kernels//mlx.metallib".
target_compile_definitions(
  mlx
  PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|