From de2b9e7d0a78439211d1433d96a24f0c4f803b5c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 12 Jun 2024 11:17:39 -0700
Subject: [PATCH] Fix kernel deps to reduce build times (#1205)

---
Note: the line structure of this patch was restored from a copy whose
newlines and indentation had been collapsed. Hunk line counts match the
@@ headers, but leading whitespace is a best-effort reconstruction; if a
context line fails to match, apply with `git apply --ignore-whitespace`.

 mlx/backend/metal/kernels/CMakeLists.txt | 184 +++++++++++------------
 1 file changed, 89 insertions(+), 95 deletions(-)

diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 81751e917..5158d5812 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -1,68 +1,12 @@
 set(
-  HEADERS
+  BASE_HEADERS
   bf16.h
   bf16_math.h
   complex.h
   defines.h
   utils.h
-  steel/conv/params.h
 )
 
-set(
-  KERNELS
-  "arg_reduce"
-  "conv"
-  "gemv"
-  "random"
-  "rms_norm"
-  "layer_norm"
-  "rope"
-  "scaled_dot_product_attention"
-)
-
-if (NOT MLX_METAL_JIT)
-set(
-  KERNELS
-  ${KERNELS}
-  "arange"
-  "binary"
-  "binary_two"
-  "unary"
-  "ternary"
-  "copy"
-  "fft"
-  "quantized"
-  "softmax"
-  "sort"
-  "scan"
-  "reduce"
-)
-set(
-  HEADERS
-  ${HEADERS}
-  atomic.h
-  arange.h
-  unary_ops.h
-  unary.h
-  binary_ops.h
-  binary.h
-  ternary.h
-  copy.h
-  fft.h
-  fft/radix.h
-  fft/readwrite.h
-  quantized.h
-  softmax.h
-  sort.h
-  scan.h
-  reduction/ops.h
-  reduction/reduce_init.h
-  reduction/reduce_all.h
-  reduction/reduce_col.h
-  reduction/reduce_row.h
-)
-endif()
-
 function(build_kernel_base TARGET SRCFILE DEPS)
   set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
   if(MLX_METAL_DEBUG)
@@ -76,7 +20,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
                   -c ${SRCFILE}
                   -I${PROJECT_SOURCE_DIR}
                   -o ${TARGET}.air
-    DEPENDS ${SRCFILE} ${DEPS}
+    DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
     OUTPUT ${TARGET}.air
     COMMENT "Building ${TARGET}.air"
     VERBATIM
@@ -85,49 +29,99 @@ endfunction(build_kernel_base)
 
 function(build_kernel KERNEL)
   set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
-  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
+  cmake_path(GET KERNEL STEM TARGET)
+  build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
+  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
 endfunction(build_kernel)
 
-foreach(KERNEL ${KERNELS})
-  build_kernel(${KERNEL})
-  set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
-endforeach()
+build_kernel(arg_reduce)
+build_kernel(conv steel/conv/params.h)
+build_kernel(gemv steel/utils.h)
+build_kernel(layer_norm)
+build_kernel(random)
+build_kernel(rms_norm)
+build_kernel(rope)
+build_kernel(
+  scaled_dot_product_attention
+  scaled_dot_product_attention_params.h
+  steel/defines.h
+  steel/gemm/transforms.h
+  steel/utils.h
+)
+
+set(
+  STEEL_HEADERS
+  steel/defines.h
+  steel/utils.h
+  steel/conv/conv.h
+  steel/conv/loader.h
+  steel/conv/loaders/loader_channel_l.h
+  steel/conv/loaders/loader_channel_n.h
+  steel/conv/loaders/loader_general.h
+  steel/conv/kernels/steel_conv.h
+  steel/conv/kernels/steel_conv_general.h
+  steel/gemm/gemm.h
+  steel/gemm/mma.h
+  steel/gemm/loader.h
+  steel/gemm/transforms.h
+  steel/gemm/kernels/steel_gemm_fused.h
+  steel/gemm/kernels/steel_gemm_masked.h
+  steel/gemm/kernels/steel_gemm_splitk.h
+)
 
 if (NOT MLX_METAL_JIT)
-  set(
-    STEEL_KERNELS
-    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
-  )
-  set(
-    STEEL_HEADERS
-    steel/defines.h
-    steel/utils.h
-    steel/conv/conv.h
-    steel/conv/loader.h
-    steel/conv/loaders/loader_channel_l.h
-    steel/conv/loaders/loader_channel_n.h
-    steel/conv/loaders/loader_general.h
-    steel/conv/kernels/steel_conv.h
-    steel/conv/kernels/steel_conv_general.h
-    steel/gemm/gemm.h
-    steel/gemm/mma.h
-    steel/gemm/loader.h
-    steel/gemm/transforms.h
-    steel/gemm/kernels/steel_gemm_fused.h
-    steel/gemm/kernels/steel_gemm_masked.h
-    steel/gemm/kernels/steel_gemm_splitk.h
-  )
-  foreach(KERNEL ${STEEL_KERNELS})
-    cmake_path(GET KERNEL STEM TARGET)
-    build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
-    set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
-  endforeach()
+build_kernel(arange arange.h)
+build_kernel(binary binary.h binary_ops.h)
+build_kernel(binary_two binary_two.h)
+build_kernel(copy copy.h)
+build_kernel(
+  fft
+  fft.h
+  fft/radix.h
+  fft/readwrite.h
+)
+build_kernel(
+  reduce
+  atomic.h
+  reduction/ops.h
+  reduction/reduce_init.h
+  reduction/reduce_all.h
+  reduction/reduce_col.h
+  reduction/reduce_row.h
+)
+build_kernel(
+  quantized
+  quantized.h
+  ${STEEL_HEADERS}
+)
+build_kernel(scan scan.h)
+build_kernel(softmax softmax.h)
+build_kernel(sort sort.h)
+build_kernel(ternary ternary.h ternary_ops.h)
+build_kernel(unary unary.h unary_ops.h)
+build_kernel(
+  steel/conv/kernels/steel_conv
+  ${STEEL_HEADERS}
+)
+build_kernel(
+  steel/conv/kernels/steel_conv_general
+  ${STEEL_HEADERS}
+)
+build_kernel(
+  steel/gemm/kernels/steel_gemm_fused
+  ${STEEL_HEADERS}
+)
+build_kernel(
+  steel/gemm/kernels/steel_gemm_masked
+  ${STEEL_HEADERS}
+)
+build_kernel(
+  steel/gemm/kernels/steel_gemm_splitk
+  ${STEEL_HEADERS}
+)
 endif()
 
+
 add_custom_command(
   OUTPUT ${MLX_METAL_PATH}/mlx.metallib
   COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib