mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Chore: add pre-commit hook for cmake (#1362)
* reset and lint * format --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,26 +1,24 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
if (MLX_BUILD_CPU)
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||
@@ -28,17 +26,15 @@ endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
elseif(MLX_BUILD_CPU)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_METAL)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
|
@@ -1,10 +1,8 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
|
||||
|
@@ -1,5 +1,4 @@
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(COMPILER ${CMAKE_C_COMPILER})
|
||||
set(CLANG TRUE)
|
||||
else()
|
||||
@@ -7,72 +6,55 @@ else()
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${CLANG}
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h)
|
||||
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
cpu_compiled_preamble
|
||||
DEPENDS compiled_preamble.cpp
|
||||
)
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if (IOS)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
|
||||
)
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||
endif()
|
||||
|
@@ -1,99 +1,56 @@
|
||||
function(make_jit_source SRC_FILE)
|
||||
# This function takes a metal header file,
|
||||
# runs the C preprocessesor on it, and makes
|
||||
# the processed contents available as a string in a C++ function
|
||||
# This function takes a metal header file, runs the C preprocessesor on it,
|
||||
# and makes the processed contents available as a string in a C++ function
|
||||
# mlx::core::metal::${SRC_NAME}()
|
||||
#
|
||||
# To use the function, declare it in jit/includes.h and
|
||||
# include jit/includes.h.
|
||||
# To use the function, declare it in jit/includes.h and include
|
||||
# jit/includes.h.
|
||||
#
|
||||
# Additional arguments to this function are treated as dependencies
|
||||
# in the Cmake build system.
|
||||
# Additional arguments to this function are treated as dependencies in the
|
||||
# Cmake build system.
|
||||
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
|
||||
add_custom_command(
|
||||
OUTPUT jit/${SRC_NAME}.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit
|
||||
${CMAKE_C_COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE}
|
||||
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/${SRC_FILE}.h
|
||||
${ARGN}
|
||||
)
|
||||
OUTPUT jit/${SRC_NAME}.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
|
||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||
add_dependencies(mlx ${SRC_NAME})
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/bf16.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h
|
||||
)
|
||||
make_jit_source(
|
||||
unary_ops
|
||||
kernels/erf.h
|
||||
kernels/expm1f.h
|
||||
)
|
||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(
|
||||
reduce_utils
|
||||
kernels/atomic.h
|
||||
kernels/reduction/ops.h
|
||||
)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if (MLX_METAL_JIT)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
|
||||
)
|
||||
if(MLX_METAL_JIT)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
|
||||
make_jit_source(arange)
|
||||
make_jit_source(copy)
|
||||
make_jit_source(unary)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(
|
||||
fft
|
||||
kernels/fft/radix.h
|
||||
kernels/fft/readwrite.h
|
||||
)
|
||||
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
make_jit_source(sort)
|
||||
make_jit_source(
|
||||
reduce
|
||||
kernels/reduction/reduce_all.h
|
||||
kernels/reduction/reduce_col.h
|
||||
kernels/reduction/reduce_row.h
|
||||
kernels/reduction/reduce_init.h
|
||||
)
|
||||
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
|
||||
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
|
||||
make_jit_source(
|
||||
steel/gemm/gemm
|
||||
kernels/steel/utils.h
|
||||
kernels/steel/gemm/loader.h
|
||||
kernels/steel/gemm/mma.h
|
||||
kernels/steel/gemm/params.h
|
||||
kernels/steel/gemm/transforms.h
|
||||
)
|
||||
steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
|
||||
kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
|
||||
kernels/steel/gemm/transforms.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(
|
||||
steel/gemm/kernels/steel_gemm_masked
|
||||
kernels/steel/defines.h
|
||||
)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
@@ -104,63 +61,51 @@ if (MLX_METAL_JIT)
|
||||
kernels/steel/conv/params.h
|
||||
kernels/steel/conv/loader.h
|
||||
kernels/steel/conv/loaders/loader_channel_l.h
|
||||
kernels/steel/conv/loaders/loader_channel_n.h
|
||||
)
|
||||
make_jit_source(
|
||||
steel/conv/kernels/steel_conv
|
||||
)
|
||||
make_jit_source(
|
||||
steel/conv/kernels/steel_conv_general
|
||||
kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h
|
||||
)
|
||||
kernels/steel/conv/loaders/loader_channel_n.h)
|
||||
make_jit_source(steel/conv/kernels/steel_conv)
|
||||
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h)
|
||||
make_jit_source(quantized)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||
endif()
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
if(NOT MLX_METAL_PATH)
|
||||
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
||||
|
||||
target_compile_definitions(
|
||||
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|
||||
target_compile_definitions(mlx
|
||||
PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|
||||
|
@@ -1,38 +1,26 @@
|
||||
set(
|
||||
BASE_HEADERS
|
||||
bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
expm1f.h
|
||||
utils.h
|
||||
)
|
||||
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
-frecord-sources)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
${METAL_FLAGS}
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
VERBATIM
|
||||
)
|
||||
VERBATIM)
|
||||
endfunction(build_kernel_base)
|
||||
|
||||
function(build_kernel KERNEL)
|
||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
|
||||
set(KERNEL_AIR
|
||||
${TARGET}.air ${KERNEL_AIR}
|
||||
PARENT_SCOPE)
|
||||
endfunction(build_kernel)
|
||||
|
||||
build_kernel(arg_reduce)
|
||||
@@ -42,106 +30,66 @@ build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention
|
||||
scaled_dot_product_attention_params.h
|
||||
steel/defines.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils.h
|
||||
)
|
||||
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
|
||||
steel/defines.h steel/gemm/transforms.h steel/utils.h)
|
||||
|
||||
set(
|
||||
STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/conv/conv.h
|
||||
steel/conv/loader.h
|
||||
steel/conv/loaders/loader_channel_l.h
|
||||
steel/conv/loaders/loader_channel_n.h
|
||||
steel/conv/loaders/loader_general.h
|
||||
steel/conv/kernels/steel_conv.h
|
||||
steel/conv/kernels/steel_conv_general.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
)
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/conv/conv.h
|
||||
steel/conv/loader.h
|
||||
steel/conv/loaders/loader_channel_l.h
|
||||
steel/conv/loaders/loader_channel_n.h
|
||||
steel/conv/loaders/loader_general.h
|
||||
steel/conv/kernels/steel_conv.h
|
||||
steel/conv/kernels/steel_conv_general.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
build_kernel(binary_two binary_two.h)
|
||||
build_kernel(copy copy.h)
|
||||
build_kernel(
|
||||
fft
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
)
|
||||
build_kernel(
|
||||
reduce
|
||||
atomic.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
build_kernel(
|
||||
quantized
|
||||
quantized.h
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv_general
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_fused
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_masked
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_splitk
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
build_kernel(binary_two binary_two.h)
|
||||
build_kernel(copy copy.h)
|
||||
build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
|
||||
build_kernel(
|
||||
reduce
|
||||
atomic.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
endif()
|
||||
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
VERBATIM
|
||||
)
|
||||
VERBATIM)
|
||||
|
||||
add_custom_target(
|
||||
mlx-metallib
|
||||
DEPENDS
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
)
|
||||
add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)
|
||||
|
||||
add_dependencies(
|
||||
mlx
|
||||
mlx-metallib
|
||||
)
|
||||
add_dependencies(mlx mlx-metallib)
|
||||
|
||||
# Install metallib
|
||||
include(GNUInstallDirs)
|
||||
@@ -149,5 +97,4 @@ include(GNUInstallDirs)
|
||||
install(
|
||||
FILES ${MLX_METAL_PATH}/mlx.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
COMPONENT metallib
|
||||
)
|
||||
COMPONENT metallib)
|
||||
|
@@ -1,11 +1,9 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
|
||||
|
@@ -1,8 +1,6 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
|
||||
|
@@ -1,16 +1,8 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
|
||||
|
||||
if (MPI_FOUND AND MLX_BUILD_CPU)
|
||||
if(MPI_FOUND AND MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
|
||||
endif()
|
||||
|
@@ -1,5 +1 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
||||
|
@@ -1,58 +1,32 @@
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
)
|
||||
|
||||
if (MLX_BUILD_SAFETENSORS)
|
||||
MESSAGE(STATUS "Downloading json")
|
||||
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||
if(MLX_BUILD_SAFETENSORS)
|
||||
message(STATUS "Downloading json")
|
||||
FetchContent_Declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||
FetchContent_MakeAvailable(json)
|
||||
target_include_directories(
|
||||
mlx PRIVATE
|
||||
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
|
||||
)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp
|
||||
)
|
||||
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_GGUF)
|
||||
MESSAGE(STATUS "Downloading gguflib")
|
||||
FetchContent_Declare(gguflib
|
||||
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
|
||||
GIT_TAG af7d88d808a7608a33723fba067036202910acb3
|
||||
)
|
||||
if(MLX_BUILD_GGUF)
|
||||
message(STATUS "Downloading gguflib")
|
||||
FetchContent_Declare(
|
||||
gguflib
|
||||
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
|
||||
GIT_TAG af7d88d808a7608a33723fba067036202910acb3)
|
||||
FetchContent_MakeAvailable(gguflib)
|
||||
target_include_directories(
|
||||
mlx PRIVATE
|
||||
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
|
||||
)
|
||||
add_library(
|
||||
gguflib STATIC
|
||||
${gguflib_SOURCE_DIR}/fp16.c
|
||||
${gguflib_SOURCE_DIR}/gguflib.c)
|
||||
target_include_directories(mlx
|
||||
PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)
|
||||
add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c
|
||||
${gguflib_SOURCE_DIR}/gguflib.c)
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)
|
||||
endif()
|
||||
|
||||
|
Reference in New Issue
Block a user