mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 01:51:18 +08:00

- Add SVDParams, JacobiRotation, and SVDConvergenceInfo structures
- Create placeholder Metal kernel declarations for SVD operations
- Add SVD kernel compilation to the CMake build system
- Update SVD::eval_gpu to dispatch to the Metal implementation
- Add basic input validation and error handling
- Include a placeholder kernel implementation for compilation

This establishes the foundation for the Metal SVD implementation. The actual algorithm implementation will follow in subsequent commits.
129 lines
5.0 KiB
CMake
129 lines
5.0 KiB
CMake
function(make_jit_source SRC_FILE)
  # This function takes a Metal header file, runs the C preprocessor on it
  # (through make_compiled_preamble.sh), and makes the processed contents
  # available as a string in a C++ function mlx::core::metal::${SRC_NAME}()
  #
  # To use the function, declare it in jit/includes.h and include
  # jit/includes.h.
  #
  # Arguments:
  #   SRC_FILE - header to embed, given relative to kernels/ and without the
  #              ".h" extension (the rule depends on kernels/${SRC_FILE}.h).
  #   ARGN     - additional arguments are treated as extra dependencies of the
  #              generated file in the CMake build system.
  #
  # Side effects: adds a custom target named after the last path component of
  # SRC_FILE, makes `mlx` depend on it, and adds the generated .cpp to `mlx`.
  get_filename_component(SRC_NAME "${SRC_FILE}" NAME)

  # Spell the generated file with one absolute path so the custom command, the
  # custom target and target_sources() all refer to exactly the same file.
  set(jit_output "${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp")

  add_custom_command(
    OUTPUT "${jit_output}"
    COMMAND
      bash "${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh"
      "${CMAKE_CURRENT_BINARY_DIR}/jit" ${CMAKE_C_COMPILER}
      "${PROJECT_SOURCE_DIR}" "${SRC_FILE}"
    # Absolute paths avoid relying on CMake's relative-DEPENDS lookup rules;
    # entries in ARGN are passed through exactly as the caller wrote them.
    DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh"
            "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${SRC_FILE}.h" ${ARGN}
    # VERBATIM makes argument escaping identical across generators/platforms.
    VERBATIM)
  add_custom_target(${SRC_NAME} DEPENDS "${jit_output}")
  add_dependencies(mlx ${SRC_NAME})
  target_sources(mlx PRIVATE "${jit_output}")
endfunction()
# JIT sources that are generated whether or not MLX_METAL_JIT is enabled. The
# first argument of each call selects kernels/<name>.h; any further arguments
# are extra headers the generated file depends on.
make_jit_source(utils kernels/jit/bf16.h kernels/metal_3_0/bf16.h
                kernels/metal_3_1/bf16.h kernels/bf16_math.h kernels/complex.h
                kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
make_jit_source(hadamard)
# With MLX_METAL_JIT enabled, build jit_kernels.cpp and generate a JIT source
# for each kernel header listed below; otherwise build nojit_kernels.cpp.
if(MLX_METAL_JIT)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)

  # Element-wise, FFT, softmax, scan, sort and SVD kernels.
  make_jit_source(arange)
  make_jit_source(copy)
  make_jit_source(unary)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
  make_jit_source(logsumexp)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
  make_jit_source(sort)
  make_jit_source(svd)

  # Reductions.
  make_jit_source(
    reduce
    kernels/reduction/reduce_all.h
    kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h
    kernels/reduction/reduce_init.h)

  # Steel GEMM.
  make_jit_source(
    steel/gemm/gemm
    kernels/steel/utils.h
    kernels/steel/gemm/loader.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/params.h
    kernels/steel/gemm/transforms.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)

  # Steel convolution.
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
    kernels/steel/defines.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/transforms.h
    kernels/steel/conv/params.h
    kernels/steel/conv/loader.h
    kernels/steel/conv/loaders/loader_channel_l.h
    kernels/steel/conv/loaders/loader_channel_n.h)
  make_jit_source(steel/conv/kernels/steel_conv)
  make_jit_source(
    steel/conv/kernels/steel_conv_general
    kernels/steel/defines.h
    kernels/steel/conv/loaders/loader_general.h)

  make_jit_source(quantized)
  make_jit_source(gemv_masked)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
endif()
# C++ sources in this directory that are always compiled into the mlx target.
# The order of the entries is kept as-is.
target_sources(
  mlx
  PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/device.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/event.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp")
# Fall back to the kernels/ directory of the build tree when MLX_METAL_PATH has
# not been provided (or is set to a false value).
if(NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH "${CMAKE_CURRENT_BINARY_DIR}/kernels/")
endif()
# Descend into kernels/ and process its CMakeLists.txt.
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/kernels")
# Expose the mlx.metallib location under MLX_METAL_PATH to the mlx sources as
# the METAL_PATH string macro.
target_compile_definitions(
  mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")