mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix kernel deps to reduce build times (#1205)
This commit is contained in:
parent
dd7d8e5e29
commit
de2b9e7d0a
@ -1,68 +1,12 @@
|
||||
set(
|
||||
HEADERS
|
||||
BASE_HEADERS
|
||||
bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
utils.h
|
||||
steel/conv/params.h
|
||||
)
|
||||
|
||||
set(
|
||||
KERNELS
|
||||
"arg_reduce"
|
||||
"conv"
|
||||
"gemv"
|
||||
"random"
|
||||
"rms_norm"
|
||||
"layer_norm"
|
||||
"rope"
|
||||
"scaled_dot_product_attention"
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
set(
|
||||
KERNELS
|
||||
${KERNELS}
|
||||
"arange"
|
||||
"binary"
|
||||
"binary_two"
|
||||
"unary"
|
||||
"ternary"
|
||||
"copy"
|
||||
"fft"
|
||||
"quantized"
|
||||
"softmax"
|
||||
"sort"
|
||||
"scan"
|
||||
"reduce"
|
||||
)
|
||||
set(
|
||||
HEADERS
|
||||
${HEADERS}
|
||||
atomic.h
|
||||
arange.h
|
||||
unary_ops.h
|
||||
unary.h
|
||||
binary_ops.h
|
||||
binary.h
|
||||
ternary.h
|
||||
copy.h
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
quantized.h
|
||||
softmax.h
|
||||
sort.h
|
||||
scan.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
endif()
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||
if(MLX_METAL_DEBUG)
|
||||
@ -76,7 +20,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
VERBATIM
|
||||
@ -85,24 +29,27 @@ endfunction(build_kernel_base)
|
||||
|
||||
function(build_kernel KERNEL)
|
||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
|
||||
endfunction(build_kernel)
|
||||
|
||||
foreach(KERNEL ${KERNELS})
|
||||
build_kernel(${KERNEL})
|
||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
build_kernel(arg_reduce)
|
||||
build_kernel(conv steel/conv/params.h)
|
||||
build_kernel(gemv steel/utils.h)
|
||||
build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention
|
||||
scaled_dot_product_attention_params.h
|
||||
steel/defines.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils.h
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
set(
|
||||
STEEL_KERNELS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
|
||||
)
|
||||
set(
|
||||
set(
|
||||
STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
@ -120,14 +67,61 @@ if (NOT MLX_METAL_JIT)
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
)
|
||||
foreach(KERNEL ${STEEL_KERNELS})
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
build_kernel(binary_two binary_two.h)
|
||||
build_kernel(copy copy.h)
|
||||
build_kernel(
|
||||
fft
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
)
|
||||
build_kernel(
|
||||
reduce
|
||||
atomic.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
build_kernel(
|
||||
quantized
|
||||
quantized.h
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv_general
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_fused
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_masked
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_splitk
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||
|
Loading…
Reference in New Issue
Block a user