diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66221d799..c2b6f1f34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,7 @@ repos: - id: isort args: - --profile=black +- repo: https://github.com/cheshirekow/cmake-format-precommit + rev: v0.6.13 + hooks: + - id: cmake-format diff --git a/CMakeLists.txt b/CMakeLists.txt index 1af9665bf..75b543566 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,18 +29,23 @@ endif() # --------------------- Processor tests ------------------------- -message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}") +message( + STATUS + "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}" +) set(MLX_BUILD_ARM OFF) -if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") +if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") if(NOT MLX_ENABLE_X64_MAC) - message(FATAL_ERROR - "Building for x86_64 on macOS is not supported." - " If you are on an Apple silicon system, check the build" - " documentation for possible fixes: " - "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source") + message( + FATAL_ERROR + "Building for x86_64 on macOS is not supported." + " If you are on an Apple silicon system, check the build" + " documentation for possible fixes: " + "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source" + ) else() message(WARNING "Building for x86_64 arch is not officially supported.") endif() @@ -61,63 +66,59 @@ cmake_policy(SET CMP0135 NEW) add_library(mlx) -if (MLX_BUILD_METAL) +if(MLX_BUILD_METAL) find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) find_library(QUARTZ_LIB QuartzCore) endif() -if (MLX_BUILD_METAL AND NOT METAL_LIB) +if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) set(MLX_METAL_DEBUG OFF) -elseif (MLX_BUILD_METAL) +elseif(MLX_BUILD_METAL) message(STATUS "Building METAL sources") - if (MLX_METAL_DEBUG) + if(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG) endif() # Throw an error if xcrun not found - execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" - OUTPUT_VARIABLE MACOS_VERSION - COMMAND_ERROR_IS_FATAL ANY) + execute_process( + COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" + OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY) - if (${MACOS_VERSION} LESS 14.0) - message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) + if(${MACOS_VERSION} LESS 14.0) + message( + FATAL_ERROR + "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON") endif() message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") - set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip) + set(METAL_CPP_URL + https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip + ) # Get the metal version execute_process( - COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" - OUTPUT_VARIABLE MLX_METAL_VERSION - COMMAND_ERROR_IS_FATAL ANY) + COMMAND + zsh "-c" + "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" + OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) - FetchContent_Declare( - metal_cpp - URL ${METAL_CPP_URL} - ) + FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) FetchContent_MakeAvailable(metal_cpp) target_include_directories( - mlx PUBLIC - $ - $ - ) - target_link_libraries( - mlx PUBLIC - ${METAL_LIB} - ${FOUNDATION_LIB} - ${QUARTZ_LIB}) + mlx PUBLIC $ + $) + target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}") endif() -if (MLX_BUILD_CPU) +if(MLX_BUILD_CPU) find_library(ACCELERATE_LIBRARY Accelerate) - if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) + if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") set(MLX_BUILD_ACCELERATE ON) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) @@ -129,32 +130,29 @@ if (MLX_BUILD_CPU) # The blas shipped in macOS SDK is not supported, search homebrew for # openblas instead. set(BLA_VENDOR OpenBLAS) - set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas") + set(LAPACK_ROOT + "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas") endif() # Search and link with lapack. find_package(LAPACK REQUIRED) - if (NOT LAPACK_FOUND) + if(NOT LAPACK_FOUND) message(FATAL_ERROR "Must have LAPACK installed") endif() - find_path(LAPACK_INCLUDE_DIRS lapacke.h - /usr/include - /usr/local/include - /usr/local/opt/openblas/include) + find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include + /usr/local/opt/openblas/include) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES}) - # List blas after lapack otherwise we may accidentally incldue an old version - # of lapack.h from the include dirs of blas. + # List blas after lapack otherwise we may accidentally incldue an old + # version of lapack.h from the include dirs of blas. find_package(BLAS REQUIRED) - if (NOT BLAS_FOUND) + if(NOT BLAS_FOUND) message(FATAL_ERROR "Must have BLAS installed") endif() # TODO find a cleaner way to do this - find_path(BLAS_INCLUDE_DIRS cblas.h - /usr/include - /usr/local/include - $ENV{BLAS_HOME}/include) + find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include + $ENV{BLAS_HOME}/include) message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) @@ -165,103 +163,94 @@ else() endif() find_package(MPI) -if (MPI_FOUND) +if(MPI_FOUND) execute_process( COMMAND zsh "-c" "mpirun --version" OUTPUT_VARIABLE MPI_VERSION - ERROR_QUIET - ) - if (${MPI_VERSION} MATCHES ".*Open MPI.*") + ERROR_QUIET) + if(${MPI_VERSION} MATCHES ".*Open MPI.*") target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH}) - elseif (MPI_VERSION STREQUAL "") + elseif(MPI_VERSION STREQUAL "") set(MPI_FOUND FALSE) message( - WARNING - "MPI found but mpirun is not available. Building without MPI." - ) + WARNING "MPI found but mpirun is not available. Building without MPI.") else() set(MPI_FOUND FALSE) - message( - WARNING - "MPI which is not OpenMPI found. Building without MPI." - ) - endif() + message(WARNING "MPI which is not OpenMPI found. Building without MPI.") + endif() endif() add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) target_include_directories( - mlx - PUBLIC - $ - $ -) + mlx PUBLIC $ + $) -FetchContent_Declare(fmt +FetchContent_Declare( + fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git - GIT_TAG 10.2.1 - EXCLUDE_FROM_ALL -) + GIT_TAG 10.2.1 + EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(fmt) target_link_libraries(mlx PRIVATE fmt::fmt-header-only) -if (MLX_BUILD_PYTHON_BINDINGS) +if(MLX_BUILD_PYTHON_BINDINGS) message(STATUS "Building Python bindings.") - find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + find_package( + Python 3.8 + COMPONENTS Interpreter Development.Module + REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE NB_DIR) list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") find_package(nanobind CONFIG REQUIRED) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) endif() -if (MLX_BUILD_TESTS) +if(MLX_BUILD_TESTS) include(CTest) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests) endif() -if (MLX_BUILD_EXAMPLES) +if(MLX_BUILD_EXAMPLES) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp) endif() -if (MLX_BUILD_BENCHMARKS) +if(MLX_BUILD_BENCHMARKS) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) endif() - - # ----------------------------- Installation ----------------------------- include(GNUInstallDirs) # Install library install( - TARGETS mlx - EXPORT MLXTargets - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} -) - + TARGETS mlx + EXPORT MLXTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) # Install headers install( - DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - COMPONENT headers - FILES_MATCHING PATTERN "*.h" -) + DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + COMPONENT headers + FILES_MATCHING + PATTERN "*.h") # Install metal dependencies -if (MLX_BUILD_METAL) +if(MLX_BUILD_METAL) # Install metal cpp install( - DIRECTORY ${metal_cpp_SOURCE_DIR}/ - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp - COMPONENT metal_cpp_source - ) + DIRECTORY ${metal_cpp_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp + COMPONENT metal_cpp_source) endif() @@ -273,31 +262,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX) install( EXPORT MLXTargets FILE MLXTargets.cmake - DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} -) + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) include(CMakePackageConfigHelpers) write_basic_package_version_file( ${MLX_CMAKE_BUILD_VERSION_CONFIG} COMPATIBILITY SameMajorVersion - VERSION ${MLX_VERSION} -) + VERSION ${MLX_VERSION}) configure_package_config_file( - ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in - ${MLX_CMAKE_BUILD_CONFIG} + ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG} INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} NO_CHECK_REQUIRED_COMPONENTS_MACRO - PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR -) + PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR + MLX_CMAKE_INSTALL_MODULE_DIR) -install( - FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG} - DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} -) +install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG} + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) -install( - DIRECTORY ${CMAKE_MODULE_PATH}/ - DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} -) +install(DIRECTORY ${CMAKE_MODULE_PATH}/ + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) diff --git a/cmake/extension.cmake b/cmake/extension.cmake index ffb02ee41..6f2354897 100644 --- a/cmake/extension.cmake +++ b/cmake/extension.cmake @@ -1,56 +1,41 @@ include(CMakeParseArguments) -############################################################################### +# ############################################################################## # Build metal library # # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} # -# Args: -# TARGET: Custom target to be added for the metal library -# TITLE: Name of the .metallib -# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib -# SOURCES: List of source files -# INCLUDE_DIRS: List of include dirs -# DEPS: List of dependency files (like headers) +# Args: TARGET: Custom target to be added for the metal library TITLE: Name of +# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List +# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency +# files (like headers) # macro(mlx_build_metallib) # Parse args set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) - cmake_parse_arguments( - MTLLIB - "" - "${oneValueArgs}" - "${multiValueArgs}" - ${ARGN} - ) + cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) # Set output set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") - # Collect compile options + # Collect compile options set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) # Prepare metallib build command add_custom_command( OUTPUT ${MTLLIB_BUILD_TARGET} - COMMAND xcrun -sdk macosx metal - "$" - ${MTLLIB_COMPILE_OPTIONS} - ${MTLLIB_SOURCES} - -o ${MTLLIB_BUILD_TARGET} + COMMAND + xcrun -sdk macosx metal + "$" + ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET} DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} COMMAND_EXPAND_LISTS COMMENT "Building ${MTLLIB_TITLE}.metallib" - VERBATIM - ) + VERBATIM) # Add metallib custom target - add_custom_target( - ${MTLLIB_TARGET} - DEPENDS - ${MTLLIB_BUILD_TARGET} - ) + add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET}) -endmacro(mlx_build_metallib) \ No newline at end of file +endmacro(mlx_build_metallib) diff --git a/examples/extensions/CMakeLists.txt b/examples/extensions/CMakeLists.txt index b58a51176..1bdb03488 100644 --- a/examples/extensions/CMakeLists.txt +++ b/examples/extensions/CMakeLists.txt @@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) # ----------------------------- Dependencies ----------------------------- find_package(MLX CONFIG REQUIRED) -find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) +find_package( + Python 3.8 + COMPONENTS Interpreter Development.Module + REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE NB_DIR) list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") find_package(nanobind CONFIG REQUIRED) @@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED) add_library(mlx_ext) # Add sources -target_sources( - mlx_ext - PUBLIC - ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp -) +target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp) # Add include headers -target_include_directories( - mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} -) +target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}) # Link to mlx target_link_libraries(mlx_ext PUBLIC mlx) @@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx) # Build metallib if(MLX_BUILD_METAL) mlx_build_metallib( - TARGET mlx_ext_metallib - TITLE mlx_ext - SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal - INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} - OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} - ) - - add_dependencies( - mlx_ext + TARGET mlx_ext_metallib - ) + TITLE + mlx_ext + SOURCES + ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal + INCLUDE_DIRS + ${PROJECT_SOURCE_DIR} + ${MLX_INCLUDE_DIRS} + OUTPUT_DIRECTORY + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) + + add_dependencies(mlx_ext mlx_ext_metallib) endif() # ----------------------------- Python Bindings ----------------------------- nanobind_add_module( _ext - NB_STATIC STABLE_ABI LTO NOMINSIZE - NB_DOMAIN mlx - ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp -) + NB_STATIC + STABLE_ABI + LTO + NOMINSIZE + NB_DOMAIN + mlx + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp) target_link_libraries(_ext PRIVATE mlx_ext) if(BUILD_SHARED_LIBS) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index f62772571..c30177966 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -1,26 +1,24 @@ target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) -if (MLX_BUILD_CPU) +if(MLX_BUILD_CPU) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu) @@ -28,17 +26,15 @@ endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) -if (MLX_BUILD_ACCELERATE) +if(MLX_BUILD_ACCELERATE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) elseif(MLX_BUILD_CPU) target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp - ) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp) endif() -if (MLX_BUILD_METAL) +if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt index e3c16910a..f718e19de 100644 --- a/mlx/backend/accelerate/CMakeLists.txt +++ b/mlx/backend/accelerate/CMakeLists.txt @@ -1,10 +1,8 @@ target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp) diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index aa0f3dab0..56343ada4 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,5 +1,4 @@ - -if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") +if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(COMPILER ${CMAKE_C_COMPILER}) set(CLANG TRUE) else() @@ -7,72 +6,55 @@ else() endif() add_custom_command( - OUTPUT compiled_preamble.cpp - COMMAND /bin/bash - ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp - ${COMPILER} - ${PROJECT_SOURCE_DIR} - ${CLANG} + OUTPUT compiled_preamble.cpp + COMMAND + /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} + ${PROJECT_SOURCE_DIR} ${CLANG} + DEPENDS make_compiled_preamble.sh + compiled_preamble.h + ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h + ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h + ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h + ${PROJECT_SOURCE_DIR}/mlx/types/complex.h + ops.h) - DEPENDS make_compiled_preamble.sh - compiled_preamble.h - ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h - ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h - ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h - ${PROJECT_SOURCE_DIR}/mlx/types/complex.h - ops.h -) - -add_custom_target( - cpu_compiled_preamble - DEPENDS compiled_preamble.cpp -) +add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) add_dependencies(mlx cpu_compiled_preamble) target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) -if (IOS) - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp - ) +if(IOS) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp) endif() diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 5ab67aeb7..7b2949ade 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,99 +1,56 @@ function(make_jit_source SRC_FILE) - # This function takes a metal header file, - # runs the C preprocessesor on it, and makes - # the processed contents available as a string in a C++ function + # This function takes a metal header file, runs the C preprocessesor on it, + # and makes the processed contents available as a string in a C++ function # mlx::core::metal::${SRC_NAME}() # - # To use the function, declare it in jit/includes.h and - # include jit/includes.h. + # To use the function, declare it in jit/includes.h and include + # jit/includes.h. # - # Additional arguments to this function are treated as dependencies - # in the Cmake build system. + # Additional arguments to this function are treated as dependencies in the + # Cmake build system. get_filename_component(SRC_NAME ${SRC_FILE} NAME) add_custom_command( - OUTPUT jit/${SRC_NAME}.cpp - COMMAND /bin/bash - ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh - ${CMAKE_CURRENT_BINARY_DIR}/jit - ${CMAKE_C_COMPILER} - ${PROJECT_SOURCE_DIR} - ${SRC_FILE} - "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}" - DEPENDS make_compiled_preamble.sh - kernels/${SRC_FILE}.h - ${ARGN} - ) + OUTPUT jit/${SRC_NAME}.cpp + COMMAND + /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} + ${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}" + DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN}) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_dependencies(mlx ${SRC_NAME}) - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) endfunction(make_jit_source) -make_jit_source( - utils - kernels/bf16.h - kernels/complex.h - kernels/defines.h -) -make_jit_source( - unary_ops - kernels/erf.h - kernels/expm1f.h -) +make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h) +make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) make_jit_source(binary_ops) make_jit_source(ternary_ops) -make_jit_source( - reduce_utils - kernels/atomic.h - kernels/reduction/ops.h -) +make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) make_jit_source(scatter) make_jit_source(gather) make_jit_source(hadamard) -if (MLX_METAL_JIT) - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp - ) +if(MLX_METAL_JIT) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp) make_jit_source(arange) make_jit_source(copy) make_jit_source(unary) make_jit_source(binary) make_jit_source(binary_two) - make_jit_source( - fft - kernels/fft/radix.h - kernels/fft/readwrite.h - ) + make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h) make_jit_source(ternary) make_jit_source(softmax) make_jit_source(scan) make_jit_source(sort) make_jit_source( - reduce - kernels/reduction/reduce_all.h - kernels/reduction/reduce_col.h - kernels/reduction/reduce_row.h - kernels/reduction/reduce_init.h - ) + reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h + kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h) make_jit_source( - steel/gemm/gemm - kernels/steel/utils.h - kernels/steel/gemm/loader.h - kernels/steel/gemm/mma.h - kernels/steel/gemm/params.h - kernels/steel/gemm/transforms.h - ) + steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h + kernels/steel/gemm/mma.h kernels/steel/gemm/params.h + kernels/steel/gemm/transforms.h) make_jit_source(steel/gemm/kernels/steel_gemm_fused) - make_jit_source( - steel/gemm/kernels/steel_gemm_masked - kernels/steel/defines.h - ) + make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source( steel/conv/conv @@ -104,63 +61,51 @@ if (MLX_METAL_JIT) kernels/steel/conv/params.h kernels/steel/conv/loader.h kernels/steel/conv/loaders/loader_channel_l.h - kernels/steel/conv/loaders/loader_channel_n.h - ) - make_jit_source( - steel/conv/kernels/steel_conv - ) - make_jit_source( - steel/conv/kernels/steel_conv_general - kernels/steel/defines.h - kernels/steel/conv/loaders/loader_general.h - ) + kernels/steel/conv/loaders/loader_channel_n.h) + make_jit_source(steel/conv/kernels/steel_conv) + make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h + kernels/steel/conv/loaders/loader_general.h) make_jit_source(quantized) make_jit_source(gemv_masked) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) endif() target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) -if (NOT MLX_METAL_PATH) +if(NOT MLX_METAL_PATH) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) -target_compile_definitions( - mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") +target_compile_definitions(mlx + PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 8699565a5..4d637a154 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -1,38 +1,26 @@ -set( - BASE_HEADERS - bf16.h - bf16_math.h - complex.h - defines.h - expm1f.h - utils.h -) +set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h) function(build_kernel_base TARGET SRCFILE DEPS) set(METAL_FLAGS -Wall -Wextra -fno-fast-math) if(MLX_METAL_DEBUG) - set(METAL_FLAGS ${METAL_FLAGS} - -gline-tables-only - -frecord-sources) + set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() add_custom_command( - COMMAND xcrun -sdk macosx metal - ${METAL_FLAGS} - -c ${SRCFILE} - -I${PROJECT_SOURCE_DIR} - -o ${TARGET}.air + COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} + -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} OUTPUT ${TARGET}.air COMMENT "Building ${TARGET}.air" - VERBATIM - ) + VERBATIM) endfunction(build_kernel_base) function(build_kernel KERNEL) set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) cmake_path(GET KERNEL STEM TARGET) build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}") - set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE) + set(KERNEL_AIR + ${TARGET}.air ${KERNEL_AIR} + PARENT_SCOPE) endfunction(build_kernel) build_kernel(arg_reduce) @@ -42,106 +30,66 @@ build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) build_kernel(rope) -build_kernel( - scaled_dot_product_attention - scaled_dot_product_attention_params.h - steel/defines.h - steel/gemm/transforms.h - steel/utils.h -) +build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h + steel/defines.h steel/gemm/transforms.h steel/utils.h) -set( - STEEL_HEADERS - steel/defines.h - steel/utils.h - steel/conv/conv.h - steel/conv/loader.h - steel/conv/loaders/loader_channel_l.h - steel/conv/loaders/loader_channel_n.h - steel/conv/loaders/loader_general.h - steel/conv/kernels/steel_conv.h - steel/conv/kernels/steel_conv_general.h - steel/gemm/gemm.h - steel/gemm/mma.h - steel/gemm/loader.h - steel/gemm/transforms.h - steel/gemm/kernels/steel_gemm_fused.h - steel/gemm/kernels/steel_gemm_masked.h - steel/gemm/kernels/steel_gemm_splitk.h -) +set(STEEL_HEADERS + steel/defines.h + steel/utils.h + steel/conv/conv.h + steel/conv/loader.h + steel/conv/loaders/loader_channel_l.h + steel/conv/loaders/loader_channel_n.h + steel/conv/loaders/loader_general.h + steel/conv/kernels/steel_conv.h + steel/conv/kernels/steel_conv_general.h + steel/gemm/gemm.h + steel/gemm/mma.h + steel/gemm/loader.h + steel/gemm/transforms.h + steel/gemm/kernels/steel_gemm_fused.h + steel/gemm/kernels/steel_gemm_masked.h + steel/gemm/kernels/steel_gemm_splitk.h) -if (NOT MLX_METAL_JIT) -build_kernel(arange arange.h) -build_kernel(binary binary.h binary_ops.h) -build_kernel(binary_two binary_two.h) -build_kernel(copy copy.h) -build_kernel( - fft - fft.h - fft/radix.h - fft/readwrite.h -) -build_kernel( - reduce - atomic.h - reduction/ops.h - reduction/reduce_init.h - reduction/reduce_all.h - reduction/reduce_col.h - reduction/reduce_row.h -) -build_kernel( - quantized - quantized.h - ${STEEL_HEADERS} -) -build_kernel(scan scan.h) -build_kernel(softmax softmax.h) -build_kernel(sort sort.h) -build_kernel(ternary ternary.h ternary_ops.h) -build_kernel(unary unary.h unary_ops.h) -build_kernel( - steel/conv/kernels/steel_conv - ${STEEL_HEADERS} -) -build_kernel( - steel/conv/kernels/steel_conv_general - ${STEEL_HEADERS} -) -build_kernel( - steel/gemm/kernels/steel_gemm_fused - ${STEEL_HEADERS} -) -build_kernel( - steel/gemm/kernels/steel_gemm_masked - ${STEEL_HEADERS} -) -build_kernel( - steel/gemm/kernels/steel_gemm_splitk - ${STEEL_HEADERS} -) -build_kernel(gemv_masked steel/utils.h) +if(NOT MLX_METAL_JIT) + build_kernel(arange arange.h) + build_kernel(binary binary.h binary_ops.h) + build_kernel(binary_two binary_two.h) + build_kernel(copy copy.h) + build_kernel(fft fft.h fft/radix.h fft/readwrite.h) + build_kernel( + reduce + atomic.h + reduction/ops.h + reduction/reduce_init.h + reduction/reduce_all.h + reduction/reduce_col.h + reduction/reduce_row.h) + build_kernel(quantized quantized.h ${STEEL_HEADERS}) + build_kernel(scan scan.h) + build_kernel(softmax softmax.h) + build_kernel(sort sort.h) + build_kernel(ternary ternary.h ternary_ops.h) + build_kernel(unary unary.h unary_ops.h) + build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) + build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) + build_kernel(gemv_masked steel/utils.h) endif() - add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib - COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib + COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o + ${MLX_METAL_PATH}/mlx.metallib DEPENDS ${KERNEL_AIR} COMMENT "Building mlx.metallib" - VERBATIM -) + VERBATIM) -add_custom_target( - mlx-metallib - DEPENDS - ${MLX_METAL_PATH}/mlx.metallib -) +add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib) -add_dependencies( - mlx - mlx-metallib -) +add_dependencies(mlx mlx-metallib) # Install metallib include(GNUInstallDirs) @@ -149,5 +97,4 @@ include(GNUInstallDirs) install( FILES ${MLX_METAL_PATH}/mlx.metallib DESTINATION ${CMAKE_INSTALL_LIBDIR} - COMPONENT metallib -) + COMPONENT metallib) diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index 50a30da43..6c5f8017a 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,11 +1,9 @@ target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp) diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_metal/CMakeLists.txt index 8f507e771..f31619e69 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_metal/CMakeLists.txt @@ -1,8 +1,6 @@ target_sources( mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp -) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index d7521a365..4009196eb 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -1,16 +1,8 @@ -target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp -) +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp) -if (MPI_FOUND AND MLX_BUILD_CPU) +if(MPI_FOUND AND MLX_BUILD_CPU) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp) endif() diff --git a/mlx/distributed/mpi/CMakeLists.txt b/mlx/distributed/mpi/CMakeLists.txt index 3caca724c..0e47d4347 100644 --- a/mlx/distributed/mpi/CMakeLists.txt +++ b/mlx/distributed/mpi/CMakeLists.txt @@ -1,5 +1 @@ -target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp -) +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt index 14a39df73..2402ff588 100644 --- a/mlx/io/CMakeLists.txt +++ b/mlx/io/CMakeLists.txt @@ -1,58 +1,32 @@ +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp) -target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp -) - -if (MLX_BUILD_SAFETENSORS) - MESSAGE(STATUS "Downloading json") - FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +if(MLX_BUILD_SAFETENSORS) + message(STATUS "Downloading json") + FetchContent_Declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) FetchContent_MakeAvailable(json) target_include_directories( - mlx PRIVATE - $ - ) - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp - ) + mlx PRIVATE $) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp) endif() -if (MLX_BUILD_GGUF) - MESSAGE(STATUS "Downloading gguflib") - FetchContent_Declare(gguflib - GIT_REPOSITORY https://github.com/antirez/gguf-tools/ - GIT_TAG af7d88d808a7608a33723fba067036202910acb3 - ) +if(MLX_BUILD_GGUF) + message(STATUS "Downloading gguflib") + FetchContent_Declare( + gguflib + GIT_REPOSITORY https://github.com/antirez/gguf-tools/ + GIT_TAG af7d88d808a7608a33723fba067036202910acb3) FetchContent_MakeAvailable(gguflib) - target_include_directories( - mlx PRIVATE - $ - ) - add_library( - gguflib STATIC - ${gguflib_SOURCE_DIR}/fp16.c - ${gguflib_SOURCE_DIR}/gguflib.c) + target_include_directories(mlx + PRIVATE $) + add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c + ${gguflib_SOURCE_DIR}/gguflib.c) target_link_libraries(mlx PRIVATE $) - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp - ) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp) endif() - diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index c74ce9c95..104ad6d69 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -1,7 +1,11 @@ nanobind_add_module( core - NB_STATIC STABLE_ABI LTO NOMINSIZE - NB_DOMAIN mlx + NB_STATIC + STABLE_ABI + LTO + NOMINSIZE + NB_DOMAIN + mlx ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp @@ -19,19 +23,14 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp -) + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) -if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) +if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) endif() -set_target_properties( - core - PROPERTIES - LIBRARY_OUTPUT_DIRECTORY - ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY} -) +set_target_properties(core PROPERTIES LIBRARY_OUTPUT_DIRECTORY + ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}) target_link_libraries(core PRIVATE mlx) target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 42bf66580..45f85de64 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,43 +1,39 @@ FetchContent_Declare( doctest GIT_REPOSITORY "https://github.com/onqtam/doctest" - GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2" -) + GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2") FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if (MLX_BUILD_METAL) - set( - METAL_TEST_SOURCES - metal_tests.cpp - ) +if(MLX_BUILD_METAL) + set(METAL_TEST_SOURCES metal_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) -target_sources(tests PRIVATE - allocator_tests.cpp - array_tests.cpp - arg_reduce_tests.cpp - autograd_tests.cpp - blas_tests.cpp - compile_tests.cpp - custom_vjp_tests.cpp - creations_tests.cpp - device_tests.cpp - einsum_tests.cpp - eval_tests.cpp - fft_tests.cpp - load_tests.cpp - ops_tests.cpp - random_tests.cpp - scheduler_tests.cpp - utils_tests.cpp - vmap_tests.cpp - linalg_tests.cpp - ${METAL_TEST_SOURCES} -) +target_sources( + tests + PRIVATE allocator_tests.cpp + array_tests.cpp + arg_reduce_tests.cpp + autograd_tests.cpp + blas_tests.cpp + compile_tests.cpp + custom_vjp_tests.cpp + creations_tests.cpp + device_tests.cpp + einsum_tests.cpp + eval_tests.cpp + fft_tests.cpp + load_tests.cpp + ops_tests.cpp + random_tests.cpp + scheduler_tests.cpp + utils_tests.cpp + vmap_tests.cpp + linalg_tests.cpp + ${METAL_TEST_SOURCES}) target_link_libraries(tests PRIVATE mlx doctest) doctest_discover_tests(tests)