Chore: add pre-commit hook for cmake (#1362)

* reset and lint

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan 2024-09-16 20:53:01 +01:00 committed by GitHub
parent adcc88e208
commit 669c27140d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 402 additions and 607 deletions

View File

@ -14,3 +14,7 @@ repos:
- id: isort - id: isort
args: args:
- --profile=black - --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@ -29,18 +29,23 @@ endif()
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}") message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF) set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC) if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR message(
"Building for x86_64 on macOS is not supported." FATAL_ERROR
" If you are on an Apple silicon system, check the build" "Building for x86_64 on macOS is not supported."
" documentation for possible fixes: " " If you are on an Apple silicon system, check the build"
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source") " documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else() else()
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
@ -61,63 +66,59 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
if (MLX_BUILD_METAL) if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal) find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation) find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore) find_library(QUARTZ_LIB QuartzCore)
endif() endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB) if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU") message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF) set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL) elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources") message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
endif() endif()
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" execute_process(
OUTPUT_VARIABLE MACOS_VERSION COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
COMMAND_ERROR_IS_FATAL ANY) OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0) if(${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif() endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip) set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
# Get the metal version # Get the metal version
execute_process( execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" COMMAND
OUTPUT_VARIABLE MLX_METAL_VERSION zsh "-c"
COMMAND_ERROR_IS_FATAL ANY) "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare( FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
metal_cpp
URL ${METAL_CPP_URL}
)
FetchContent_MakeAvailable(metal_cpp) FetchContent_MakeAvailable(metal_cpp)
target_include_directories( target_include_directories(
mlx PUBLIC mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}> $<INSTALL_INTERFACE:include/metal_cpp>)
$<INSTALL_INTERFACE:include/metal_cpp> target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
)
target_link_libraries(
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}") add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif() endif()
if (MLX_BUILD_CPU) if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate) find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@ -129,32 +130,29 @@ if (MLX_BUILD_CPU)
# The blas shipped in macOS SDK is not supported, search homebrew for # The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead. # openblas instead.
set(BLA_VENDOR OpenBLAS) set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas") set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif() endif()
# Search and link with lapack. # Search and link with lapack.
find_package(LAPACK REQUIRED) find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND) if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed") message(FATAL_ERROR "Must have LAPACK installed")
endif() endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/include /usr/local/opt/openblas/include)
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES}) target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version # List blas after lapack otherwise we may accidentally incldue an old
# of lapack.h from the include dirs of blas. # version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED) find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND) if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed") message(FATAL_ERROR "Must have BLAS installed")
endif() endif()
# TODO find a cleaner way to do this # TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
/usr/include $ENV{BLAS_HOME}/include)
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
@ -165,103 +163,94 @@ else()
endif() endif()
find_package(MPI) find_package(MPI)
if (MPI_FOUND) if(MPI_FOUND)
execute_process( execute_process(
COMMAND zsh "-c" "mpirun --version" COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET ERROR_QUIET)
) if(${MPI_VERSION} MATCHES ".*Open MPI.*")
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH}) target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif (MPI_VERSION STREQUAL "") elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE) set(MPI_FOUND FALSE)
message( message(
WARNING WARNING "MPI found but mpirun is not available. Building without MPI.")
"MPI found but mpirun is not available. Building without MPI."
)
else() else()
set(MPI_FOUND FALSE) set(MPI_FOUND FALSE)
message( message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
WARNING endif()
"MPI which is not OpenMPI found. Building without MPI."
)
endif()
endif() endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories( target_include_directories(
mlx mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
PUBLIC $<INSTALL_INTERFACE:include>)
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(fmt FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1 GIT_TAG 10.2.1
EXCLUDE_FROM_ALL EXCLUDE_FROM_ALL)
)
FetchContent_MakeAvailable(fmt) FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only) target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if (MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif() endif()
if (MLX_BUILD_TESTS) if(MLX_BUILD_TESTS)
include(CTest) include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif() endif()
if (MLX_BUILD_EXAMPLES) if(MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif() endif()
if (MLX_BUILD_BENCHMARKS) if(MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif() endif()
# ----------------------------- Installation ----------------------------- # ----------------------------- Installation -----------------------------
include(GNUInstallDirs) include(GNUInstallDirs)
# Install library # Install library
install( install(
TARGETS mlx TARGETS mlx
EXPORT MLXTargets EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} INCLUDES
) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
# Install headers # Install headers
install( install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers COMPONENT headers
FILES_MATCHING PATTERN "*.h" FILES_MATCHING
) PATTERN "*.h")
# Install metal dependencies # Install metal dependencies
if (MLX_BUILD_METAL) if(MLX_BUILD_METAL)
# Install metal cpp # Install metal cpp
install( install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/ DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source COMPONENT metal_cpp_source)
)
endif() endif()
@ -273,31 +262,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install( install(
EXPORT MLXTargets EXPORT MLXTargets
FILE MLXTargets.cmake FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
)
include(CMakePackageConfigHelpers) include(CMakePackageConfigHelpers)
write_basic_package_version_file( write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION} VERSION ${MLX_VERSION})
)
configure_package_config_file( configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
) MLX_CMAKE_INSTALL_MODULE_DIR)
install( install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG} DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install( install(DIRECTORY ${CMAKE_MODULE_PATH}/
DIRECTORY ${CMAKE_MODULE_PATH}/ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

View File

@ -1,56 +1,41 @@
include(CMakeParseArguments) include(CMakeParseArguments)
############################################################################### # ##############################################################################
# Build metal library # Build metal library
# #
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
# #
# Args: # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# TARGET: Custom target to be added for the metal library # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# TITLE: Name of the .metallib # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib # files (like headers)
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
# #
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments( cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output # Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET} OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal COMMAND
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>" xcrun -sdk macosx metal
${MTLLIB_COMPILE_OPTIONS} "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_SOURCES} ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib" COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM VERBATIM)
)
# Add metallib custom target # Add metallib custom target
add_custom_target( add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib) endmacro(mlx_build_metallib)

View File

@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies ----------------------------- # ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext) add_library(mlx_ext)
# Add sources # Add sources
target_sources( target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers # Add include headers
target_include_directories( target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx # Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx) target_link_libraries(mlx_ext PUBLIC mlx)
@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib # Build metallib
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
mlx_build_metallib( mlx_build_metallib(
TARGET mlx_ext_metallib TARGET
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib mlx_ext_metallib
) TITLE
mlx_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib)
endif() endif()
# ----------------------------- Python Bindings ----------------------------- # ----------------------------- Python Bindings -----------------------------
nanobind_add_module( nanobind_add_module(
_ext _ext
NB_STATIC STABLE_ABI LTO NOMINSIZE NB_STATIC
NB_DOMAIN mlx STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp LTO
) NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext) target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)

View File

@ -1,26 +1,24 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if (MLX_BUILD_CPU) if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
@ -28,17 +26,15 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU) elseif(MLX_BUILD_CPU)
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
endif() endif()
if (MLX_BUILD_METAL) if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@ -1,10 +1,8 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@ -1,5 +1,4 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER}) set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE) set(CLANG TRUE)
else() else()
@ -7,72 +6,55 @@ else()
endif() endif()
add_custom_command( add_custom_command(
OUTPUT compiled_preamble.cpp OUTPUT compiled_preamble.cpp
COMMAND /bin/bash COMMAND
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${COMPILER} ${PROJECT_SOURCE_DIR} ${CLANG}
${PROJECT_SOURCE_DIR} DEPENDS make_compiled_preamble.sh
${CLANG} compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h)
DEPENDS make_compiled_preamble.sh add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble) add_dependencies(mlx cpu_compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (IOS) if(IOS)
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif() endif()

View File

@ -1,99 +1,56 @@
function(make_jit_source SRC_FILE) function(make_jit_source SRC_FILE)
# This function takes a metal header file, # This function takes a metal header file, runs the C preprocessesor on it,
# runs the C preprocessesor on it, and makes # and makes the processed contents available as a string in a C++ function
# the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}() # mlx::core::metal::${SRC_NAME}()
# #
# To use the function, declare it in jit/includes.h and # To use the function, declare it in jit/includes.h and include
# include jit/includes.h. # jit/includes.h.
# #
# Additional arguments to this function are treated as dependencies # Additional arguments to this function are treated as dependencies in the
# in the Cmake build system. # Cmake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME) get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command( add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash COMMAND
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${CMAKE_C_COMPILER} ${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
${PROJECT_SOURCE_DIR} DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
${SRC_FILE}
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME}) add_dependencies(mlx ${SRC_NAME})
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source( make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
utils make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
make_jit_source( make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(scatter) make_jit_source(scatter)
make_jit_source(gather) make_jit_source(gather)
make_jit_source(hadamard) make_jit_source(hadamard)
if (MLX_METAL_JIT) if(MLX_METAL_JIT)
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange) make_jit_source(arange)
make_jit_source(copy) make_jit_source(copy)
make_jit_source(unary) make_jit_source(unary)
make_jit_source(binary) make_jit_source(binary)
make_jit_source(binary_two) make_jit_source(binary_two)
make_jit_source( make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary) make_jit_source(ternary)
make_jit_source(softmax) make_jit_source(softmax)
make_jit_source(scan) make_jit_source(scan)
make_jit_source(sort) make_jit_source(sort)
make_jit_source( make_jit_source(
reduce reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
kernels/reduction/reduce_all.h kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
kernels/reduction/reduce_init.h
)
make_jit_source( make_jit_source(
steel/gemm/gemm steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
kernels/steel/utils.h kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
kernels/steel/gemm/loader.h kernels/steel/gemm/transforms.h)
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source( make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
@ -104,63 +61,51 @@ if (MLX_METAL_JIT)
kernels/steel/conv/params.h kernels/steel/conv/params.h
kernels/steel/conv/loader.h kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h kernels/steel/conv/loaders/loader_channel_n.h)
) make_jit_source(steel/conv/kernels/steel_conv)
make_jit_source( make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
steel/conv/kernels/steel_conv kernels/steel/conv/loaders/loader_general.h)
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized) make_jit_source(quantized)
make_jit_source(gemv_masked) make_jit_source(gemv_masked)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif() endif()
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_METAL_PATH) if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif() endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions( target_compile_definitions(mlx
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")

View File

@ -1,38 +1,26 @@
set( set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
BASE_HEADERS
bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h
)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
-gline-tables-only
-frecord-sources)
endif() endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
${METAL_FLAGS} -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air" COMMENT "Building ${TARGET}.air"
VERBATIM VERBATIM)
)
endfunction(build_kernel_base) endfunction(build_kernel_base)
function(build_kernel KERNEL) function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
cmake_path(GET KERNEL STEM TARGET) cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}") build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE) set(KERNEL_AIR
${TARGET}.air ${KERNEL_AIR}
PARENT_SCOPE)
endfunction(build_kernel) endfunction(build_kernel)
build_kernel(arg_reduce) build_kernel(arg_reduce)
@ -42,106 +30,66 @@ build_kernel(layer_norm)
build_kernel(random) build_kernel(random)
build_kernel(rms_norm) build_kernel(rms_norm)
build_kernel(rope) build_kernel(rope)
build_kernel( build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
scaled_dot_product_attention steel/defines.h steel/gemm/transforms.h steel/utils.h)
scaled_dot_product_attention_params.h
steel/defines.h
steel/gemm/transforms.h
steel/utils.h
)
set( set(STEEL_HEADERS
STEEL_HEADERS steel/defines.h
steel/defines.h steel/utils.h
steel/utils.h steel/conv/conv.h
steel/conv/conv.h steel/conv/loader.h
steel/conv/loader.h steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_l.h steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_channel_n.h steel/conv/loaders/loader_general.h
steel/conv/loaders/loader_general.h steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv.h steel/conv/kernels/steel_conv_general.h
steel/conv/kernels/steel_conv_general.h steel/gemm/gemm.h
steel/gemm/gemm.h steel/gemm/mma.h
steel/gemm/mma.h steel/gemm/loader.h
steel/gemm/loader.h steel/gemm/transforms.h
steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_splitk.h)
steel/gemm/kernels/steel_gemm_splitk.h
)
if (NOT MLX_METAL_JIT) if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h) build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h) build_kernel(binary binary.h binary_ops.h)
build_kernel(binary_two binary_two.h) build_kernel(binary_two binary_two.h)
build_kernel(copy copy.h) build_kernel(copy copy.h)
build_kernel( build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
fft build_kernel(
fft.h reduce
fft/radix.h atomic.h
fft/readwrite.h reduction/ops.h
) reduction/reduce_init.h
build_kernel( reduction/reduce_all.h
reduce reduction/reduce_col.h
atomic.h reduction/reduce_row.h)
reduction/ops.h build_kernel(quantized quantized.h ${STEEL_HEADERS})
reduction/reduce_init.h build_kernel(scan scan.h)
reduction/reduce_all.h build_kernel(softmax softmax.h)
reduction/reduce_col.h build_kernel(sort sort.h)
reduction/reduce_row.h build_kernel(ternary ternary.h ternary_ops.h)
) build_kernel(unary unary.h unary_ops.h)
build_kernel( build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
quantized build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
quantized.h build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
${STEEL_HEADERS} build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(gemv_masked steel/utils.h)
build_kernel(softmax softmax.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(
steel/conv/kernels/steel_conv
${STEEL_HEADERS}
)
build_kernel(
steel/conv/kernels/steel_conv_general
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_fused
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_masked
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_splitk
${STEEL_HEADERS}
)
build_kernel(gemv_masked steel/utils.h)
endif() endif()
add_custom_command( add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR} DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib" COMMENT "Building mlx.metallib"
VERBATIM VERBATIM)
)
add_custom_target( add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)
mlx-metallib
DEPENDS
${MLX_METAL_PATH}/mlx.metallib
)
add_dependencies( add_dependencies(mlx mlx-metallib)
mlx
mlx-metallib
)
# Install metallib # Install metallib
include(GNUInstallDirs) include(GNUInstallDirs)
@ -149,5 +97,4 @@ include(GNUInstallDirs)
install( install(
FILES ${MLX_METAL_PATH}/mlx.metallib FILES ${MLX_METAL_PATH}/mlx.metallib
DESTINATION ${CMAKE_INSTALL_LIBDIR} DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT metallib COMPONENT metallib)
)

View File

@ -1,11 +1,9 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
)

View File

@ -1,8 +1,6 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)

View File

@ -1,16 +1,8 @@
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
mlx ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
)
if (MPI_FOUND AND MLX_BUILD_CPU) if(MPI_FOUND AND MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
)
endif() endif()

View File

@ -1,5 +1 @@
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
)

View File

@ -1,58 +1,32 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)
target_sources( if(MLX_BUILD_SAFETENSORS)
mlx message(STATUS "Downloading json")
PRIVATE FetchContent_Declare(
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp json
) URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
if (MLX_BUILD_SAFETENSORS)
MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json) FetchContent_MakeAvailable(json)
target_include_directories( target_include_directories(
mlx PRIVATE mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann> target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp
)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp
)
endif() endif()
if (MLX_BUILD_GGUF) if(MLX_BUILD_GGUF)
MESSAGE(STATUS "Downloading gguflib") message(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib FetchContent_Declare(
GIT_REPOSITORY https://github.com/antirez/gguf-tools/ gguflib
GIT_TAG af7d88d808a7608a33723fba067036202910acb3 GIT_REPOSITORY https://github.com/antirez/gguf-tools/
) GIT_TAG af7d88d808a7608a33723fba067036202910acb3)
FetchContent_MakeAvailable(gguflib) FetchContent_MakeAvailable(gguflib)
target_include_directories( target_include_directories(mlx
mlx PRIVATE PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}> add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c
) ${gguflib_SOURCE_DIR}/gguflib.c)
add_library(
gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c
${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
mlx ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp
)
endif() endif()

View File

@ -1,7 +1,11 @@
nanobind_add_module( nanobind_add_module(
core core
NB_STATIC STABLE_ABI LTO NOMINSIZE NB_STATIC
NB_DOMAIN mlx STABLE_ABI
LTO
NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
@ -19,19 +23,14 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
)
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
endif() endif()
set_target_properties( set_target_properties(core PROPERTIES LIBRARY_OUTPUT_DIRECTORY
core ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
)
target_link_libraries(core PRIVATE mlx) target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

View File

@ -1,43 +1,39 @@
FetchContent_Declare( FetchContent_Declare(
doctest doctest
GIT_REPOSITORY "https://github.com/onqtam/doctest" GIT_REPOSITORY "https://github.com/onqtam/doctest"
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2" GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
)
FetchContent_MakeAvailable(doctest) FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
if (MLX_BUILD_METAL) if(MLX_BUILD_METAL)
set( set(METAL_TEST_SOURCES metal_tests.cpp)
METAL_TEST_SOURCES
metal_tests.cpp
)
endif() endif()
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
target_sources(tests PRIVATE target_sources(
allocator_tests.cpp tests
array_tests.cpp PRIVATE allocator_tests.cpp
arg_reduce_tests.cpp array_tests.cpp
autograd_tests.cpp arg_reduce_tests.cpp
blas_tests.cpp autograd_tests.cpp
compile_tests.cpp blas_tests.cpp
custom_vjp_tests.cpp compile_tests.cpp
creations_tests.cpp custom_vjp_tests.cpp
device_tests.cpp creations_tests.cpp
einsum_tests.cpp device_tests.cpp
eval_tests.cpp einsum_tests.cpp
fft_tests.cpp eval_tests.cpp
load_tests.cpp fft_tests.cpp
ops_tests.cpp load_tests.cpp
random_tests.cpp ops_tests.cpp
scheduler_tests.cpp random_tests.cpp
utils_tests.cpp scheduler_tests.cpp
vmap_tests.cpp utils_tests.cpp
linalg_tests.cpp vmap_tests.cpp
${METAL_TEST_SOURCES} linalg_tests.cpp
) ${METAL_TEST_SOURCES})
target_link_libraries(tests PRIVATE mlx doctest) target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests) doctest_discover_tests(tests)