mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix kernel deps to reduce build times (#1205)
This commit is contained in:
parent
dd7d8e5e29
commit
de2b9e7d0a
@ -1,68 +1,12 @@
|
|||||||
set(
|
set(
|
||||||
HEADERS
|
BASE_HEADERS
|
||||||
bf16.h
|
bf16.h
|
||||||
bf16_math.h
|
bf16_math.h
|
||||||
complex.h
|
complex.h
|
||||||
defines.h
|
defines.h
|
||||||
utils.h
|
utils.h
|
||||||
steel/conv/params.h
|
|
||||||
)
|
)
|
||||||
|
|
||||||
set(
|
|
||||||
KERNELS
|
|
||||||
"arg_reduce"
|
|
||||||
"conv"
|
|
||||||
"gemv"
|
|
||||||
"random"
|
|
||||||
"rms_norm"
|
|
||||||
"layer_norm"
|
|
||||||
"rope"
|
|
||||||
"scaled_dot_product_attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (NOT MLX_METAL_JIT)
|
|
||||||
set(
|
|
||||||
KERNELS
|
|
||||||
${KERNELS}
|
|
||||||
"arange"
|
|
||||||
"binary"
|
|
||||||
"binary_two"
|
|
||||||
"unary"
|
|
||||||
"ternary"
|
|
||||||
"copy"
|
|
||||||
"fft"
|
|
||||||
"quantized"
|
|
||||||
"softmax"
|
|
||||||
"sort"
|
|
||||||
"scan"
|
|
||||||
"reduce"
|
|
||||||
)
|
|
||||||
set(
|
|
||||||
HEADERS
|
|
||||||
${HEADERS}
|
|
||||||
atomic.h
|
|
||||||
arange.h
|
|
||||||
unary_ops.h
|
|
||||||
unary.h
|
|
||||||
binary_ops.h
|
|
||||||
binary.h
|
|
||||||
ternary.h
|
|
||||||
copy.h
|
|
||||||
fft.h
|
|
||||||
fft/radix.h
|
|
||||||
fft/readwrite.h
|
|
||||||
quantized.h
|
|
||||||
softmax.h
|
|
||||||
sort.h
|
|
||||||
scan.h
|
|
||||||
reduction/ops.h
|
|
||||||
reduction/reduce_init.h
|
|
||||||
reduction/reduce_all.h
|
|
||||||
reduction/reduce_col.h
|
|
||||||
reduction/reduce_row.h
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
@ -76,7 +20,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
|||||||
-c ${SRCFILE}
|
-c ${SRCFILE}
|
||||||
-I${PROJECT_SOURCE_DIR}
|
-I${PROJECT_SOURCE_DIR}
|
||||||
-o ${TARGET}.air
|
-o ${TARGET}.air
|
||||||
DEPENDS ${SRCFILE} ${DEPS}
|
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||||
OUTPUT ${TARGET}.air
|
OUTPUT ${TARGET}.air
|
||||||
COMMENT "Building ${TARGET}.air"
|
COMMENT "Building ${TARGET}.air"
|
||||||
VERBATIM
|
VERBATIM
|
||||||
@ -85,23 +29,26 @@ endfunction(build_kernel_base)
|
|||||||
|
|
||||||
function(build_kernel KERNEL)
|
function(build_kernel KERNEL)
|
||||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||||
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
|
cmake_path(GET KERNEL STEM TARGET)
|
||||||
|
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
||||||
|
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
|
||||||
endfunction(build_kernel)
|
endfunction(build_kernel)
|
||||||
|
|
||||||
foreach(KERNEL ${KERNELS})
|
build_kernel(arg_reduce)
|
||||||
build_kernel(${KERNEL})
|
build_kernel(conv steel/conv/params.h)
|
||||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
build_kernel(gemv steel/utils.h)
|
||||||
endforeach()
|
build_kernel(layer_norm)
|
||||||
|
build_kernel(random)
|
||||||
if (NOT MLX_METAL_JIT)
|
build_kernel(rms_norm)
|
||||||
set(
|
build_kernel(rope)
|
||||||
STEEL_KERNELS
|
build_kernel(
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
|
scaled_dot_product_attention
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
|
scaled_dot_product_attention_params.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
|
steel/defines.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
|
steel/gemm/transforms.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
|
steel/utils.h
|
||||||
)
|
)
|
||||||
|
|
||||||
set(
|
set(
|
||||||
STEEL_HEADERS
|
STEEL_HEADERS
|
||||||
steel/defines.h
|
steel/defines.h
|
||||||
@ -121,13 +68,60 @@ if (NOT MLX_METAL_JIT)
|
|||||||
steel/gemm/kernels/steel_gemm_masked.h
|
steel/gemm/kernels/steel_gemm_masked.h
|
||||||
steel/gemm/kernels/steel_gemm_splitk.h
|
steel/gemm/kernels/steel_gemm_splitk.h
|
||||||
)
|
)
|
||||||
foreach(KERNEL ${STEEL_KERNELS})
|
|
||||||
cmake_path(GET KERNEL STEM TARGET)
|
if (NOT MLX_METAL_JIT)
|
||||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
build_kernel(arange arange.h)
|
||||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
build_kernel(binary binary.h binary_ops.h)
|
||||||
endforeach()
|
build_kernel(binary_two binary_two.h)
|
||||||
|
build_kernel(copy copy.h)
|
||||||
|
build_kernel(
|
||||||
|
fft
|
||||||
|
fft.h
|
||||||
|
fft/radix.h
|
||||||
|
fft/readwrite.h
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
reduce
|
||||||
|
atomic.h
|
||||||
|
reduction/ops.h
|
||||||
|
reduction/reduce_init.h
|
||||||
|
reduction/reduce_all.h
|
||||||
|
reduction/reduce_col.h
|
||||||
|
reduction/reduce_row.h
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
quantized
|
||||||
|
quantized.h
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
|
build_kernel(scan scan.h)
|
||||||
|
build_kernel(softmax softmax.h)
|
||||||
|
build_kernel(sort sort.h)
|
||||||
|
build_kernel(ternary ternary.h ternary_ops.h)
|
||||||
|
build_kernel(unary unary.h unary_ops.h)
|
||||||
|
build_kernel(
|
||||||
|
steel/conv/kernels/steel_conv
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
steel/conv/kernels/steel_conv_general
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
steel/gemm/kernels/steel_gemm_fused
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
steel/gemm/kernels/steel_gemm_masked
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
|
build_kernel(
|
||||||
|
steel/gemm/kernels/steel_gemm_splitk
|
||||||
|
${STEEL_HEADERS}
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
Loading…
Reference in New Issue
Block a user