JIT compile option for binary minimization (#1091)

* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
This commit is contained in:
Awni Hannun
2024-05-22 12:57:13 -07:00
committed by GitHub
parent d568c7ee36
commit 226748b3e7
56 changed files with 3153 additions and 2605 deletions

View File

@@ -3,13 +3,8 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -17,10 +12,7 @@ set(
KERNELS
"arange"
"arg_reduce"
"binary"
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
@@ -32,12 +24,30 @@ set(
"scaled_dot_product_attention"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"
)
# When MLX_METAL_JIT is off, add the binary, unary, ternary, and copy kernels
# (and the headers they need) to the ahead-of-time build lists. With the option
# on they are left out here; presumably they are compiled at runtime instead --
# confirm against the JIT sources.
if(NOT MLX_METAL_JIT)
  list(
    APPEND KERNELS
    "binary"
    "binary_two"
    "unary"
    "ternary"
    "copy"
  )
  # Anchor every header to this directory, matching the other HEADERS entries,
  # so consumers get unambiguous absolute paths instead of bare names that CMake
  # would have to resolve (and could mistake for a target name).
  list(
    APPEND HEADERS
    ${CMAKE_CURRENT_SOURCE_DIR}/unary_ops.h
    ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
    ${CMAKE_CURRENT_SOURCE_DIR}/binary_ops.h
    ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
    ${CMAKE_CURRENT_SOURCE_DIR}/ternary.h
    ${CMAKE_CURRENT_SOURCE_DIR}/copy.h
  )
endif()
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
if(MLX_METAL_DEBUG)