More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun
2024-05-23 16:23:44 -07:00
committed by GitHub
parent 9401507336
commit 0189ab6ab6
41 changed files with 2377 additions and 1846 deletions

View File

@@ -8,9 +8,9 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
set(
KERNELS
"arange"
"arg_reduce"
"conv"
"fft"
@@ -20,31 +20,42 @@ set(
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
"softmax"
"sort"
)
if (NOT MLX_METAL_JIT)
set(
KERNELS
${KERNELS}
"arange"
"binary"
"binary_two"
"unary"
"ternary"
"copy"
"softmax"
"sort"
"scan"
"reduce"
)
set(
HEADERS
${HEADERS}
arange.h
unary_ops.h
unary.h
binary_ops.h
binary.h
ternary.h
copy.h
softmax.h
sort.h
scan.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
endif()
@@ -87,15 +98,6 @@ foreach(KERNEL ${STEEL_KERNELS})
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
foreach(KERNEL ${REDUCE_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib