mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Chore: add pre-commit hook for cmake (#1362)
* reset and lint * format --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
adcc88e208
commit
669c27140d
@ -14,3 +14,7 @@ repos:
|
|||||||
- id: isort
|
- id: isort
|
||||||
args:
|
args:
|
||||||
- --profile=black
|
- --profile=black
|
||||||
|
- repo: https://github.com/cheshirekow/cmake-format-precommit
|
||||||
|
rev: v0.6.13
|
||||||
|
hooks:
|
||||||
|
- id: cmake-format
|
||||||
|
172
CMakeLists.txt
172
CMakeLists.txt
@ -29,18 +29,23 @@ endif()
|
|||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
|
||||||
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
message(
|
||||||
|
STATUS
|
||||||
|
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
set(MLX_BUILD_ARM OFF)
|
set(MLX_BUILD_ARM OFF)
|
||||||
|
|
||||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||||
if(NOT MLX_ENABLE_X64_MAC)
|
if(NOT MLX_ENABLE_X64_MAC)
|
||||||
message(FATAL_ERROR
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
"Building for x86_64 on macOS is not supported."
|
"Building for x86_64 on macOS is not supported."
|
||||||
" If you are on an Apple silicon system, check the build"
|
" If you are on an Apple silicon system, check the build"
|
||||||
" documentation for possible fixes: "
|
" documentation for possible fixes: "
|
||||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||||
|
)
|
||||||
else()
|
else()
|
||||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
endif()
|
endif()
|
||||||
@ -61,63 +66,59 @@ cmake_policy(SET CMP0135 NEW)
|
|||||||
|
|
||||||
add_library(mlx)
|
add_library(mlx)
|
||||||
|
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
find_library(METAL_LIB Metal)
|
find_library(METAL_LIB Metal)
|
||||||
find_library(FOUNDATION_LIB Foundation)
|
find_library(FOUNDATION_LIB Foundation)
|
||||||
find_library(QUARTZ_LIB QuartzCore)
|
find_library(QUARTZ_LIB QuartzCore)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
set(MLX_METAL_DEBUG OFF)
|
set(MLX_METAL_DEBUG OFF)
|
||||||
elseif (MLX_BUILD_METAL)
|
elseif(MLX_BUILD_METAL)
|
||||||
message(STATUS "Building METAL sources")
|
message(STATUS "Building METAL sources")
|
||||||
|
|
||||||
if (MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
add_compile_definitions(MLX_METAL_DEBUG)
|
add_compile_definitions(MLX_METAL_DEBUG)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Throw an error if xcrun not found
|
# Throw an error if xcrun not found
|
||||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
execute_process(
|
||||||
OUTPUT_VARIABLE MACOS_VERSION
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
COMMAND_ERROR_IS_FATAL ANY)
|
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
if (${MACOS_VERSION} LESS 14.0)
|
if(${MACOS_VERSION} LESS 14.0)
|
||||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
|
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||||
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
|
set(METAL_CPP_URL
|
||||||
|
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||||
|
)
|
||||||
# Get the metal version
|
# Get the metal version
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
COMMAND
|
||||||
OUTPUT_VARIABLE MLX_METAL_VERSION
|
zsh "-c"
|
||||||
COMMAND_ERROR_IS_FATAL ANY)
|
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||||
|
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||||
metal_cpp
|
|
||||||
URL ${METAL_CPP_URL}
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_MakeAvailable(metal_cpp)
|
FetchContent_MakeAvailable(metal_cpp)
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx PUBLIC
|
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||||
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||||
$<INSTALL_INTERFACE:include/metal_cpp>
|
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||||
)
|
|
||||||
target_link_libraries(
|
|
||||||
mlx PUBLIC
|
|
||||||
${METAL_LIB}
|
|
||||||
${FOUNDATION_LIB}
|
|
||||||
${QUARTZ_LIB})
|
|
||||||
|
|
||||||
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_CPU)
|
if(MLX_BUILD_CPU)
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||||
set(MLX_BUILD_ACCELERATE ON)
|
set(MLX_BUILD_ACCELERATE ON)
|
||||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||||
@ -129,31 +130,28 @@ if (MLX_BUILD_CPU)
|
|||||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||||
# openblas instead.
|
# openblas instead.
|
||||||
set(BLA_VENDOR OpenBLAS)
|
set(BLA_VENDOR OpenBLAS)
|
||||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
set(LAPACK_ROOT
|
||||||
|
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||||
endif()
|
endif()
|
||||||
# Search and link with lapack.
|
# Search and link with lapack.
|
||||||
find_package(LAPACK REQUIRED)
|
find_package(LAPACK REQUIRED)
|
||||||
if (NOT LAPACK_FOUND)
|
if(NOT LAPACK_FOUND)
|
||||||
message(FATAL_ERROR "Must have LAPACK installed")
|
message(FATAL_ERROR "Must have LAPACK installed")
|
||||||
endif()
|
endif()
|
||||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
|
||||||
/usr/include
|
|
||||||
/usr/local/include
|
|
||||||
/usr/local/opt/openblas/include)
|
/usr/local/opt/openblas/include)
|
||||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||||
# List blas after lapack otherwise we may accidentally incldue an old version
|
# List blas after lapack otherwise we may accidentally incldue an old
|
||||||
# of lapack.h from the include dirs of blas.
|
# version of lapack.h from the include dirs of blas.
|
||||||
find_package(BLAS REQUIRED)
|
find_package(BLAS REQUIRED)
|
||||||
if (NOT BLAS_FOUND)
|
if(NOT BLAS_FOUND)
|
||||||
message(FATAL_ERROR "Must have BLAS installed")
|
message(FATAL_ERROR "Must have BLAS installed")
|
||||||
endif()
|
endif()
|
||||||
# TODO find a cleaner way to do this
|
# TODO find a cleaner way to do this
|
||||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
|
||||||
/usr/include
|
|
||||||
/usr/local/include
|
|
||||||
$ENV{BLAS_HOME}/include)
|
$ENV{BLAS_HOME}/include)
|
||||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||||
@ -165,72 +163,65 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_package(MPI)
|
find_package(MPI)
|
||||||
if (MPI_FOUND)
|
if(MPI_FOUND)
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND zsh "-c" "mpirun --version"
|
COMMAND zsh "-c" "mpirun --version"
|
||||||
OUTPUT_VARIABLE MPI_VERSION
|
OUTPUT_VARIABLE MPI_VERSION
|
||||||
ERROR_QUIET
|
ERROR_QUIET)
|
||||||
)
|
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||||
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
|
|
||||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||||
elseif (MPI_VERSION STREQUAL "")
|
elseif(MPI_VERSION STREQUAL "")
|
||||||
set(MPI_FOUND FALSE)
|
set(MPI_FOUND FALSE)
|
||||||
message(
|
message(
|
||||||
WARNING
|
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||||
"MPI found but mpirun is not available. Building without MPI."
|
|
||||||
)
|
|
||||||
else()
|
else()
|
||||||
set(MPI_FOUND FALSE)
|
set(MPI_FOUND FALSE)
|
||||||
message(
|
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||||
WARNING
|
|
||||||
"MPI which is not OpenMPI found. Building without MPI."
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
PUBLIC
|
$<INSTALL_INTERFACE:include>)
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
|
||||||
$<INSTALL_INTERFACE:include>
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_Declare(fmt
|
FetchContent_Declare(
|
||||||
|
fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
GIT_TAG 10.2.1
|
GIT_TAG 10.2.1
|
||||||
EXCLUDE_FROM_ALL
|
EXCLUDE_FROM_ALL)
|
||||||
)
|
|
||||||
FetchContent_MakeAvailable(fmt)
|
FetchContent_MakeAvailable(fmt)
|
||||||
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
|
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
|
||||||
|
|
||||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
message(STATUS "Building Python bindings.")
|
message(STATUS "Building Python bindings.")
|
||||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
find_package(
|
||||||
|
Python 3.8
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE NB_DIR)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_TESTS)
|
if(MLX_BUILD_TESTS)
|
||||||
include(CTest)
|
include(CTest)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_EXAMPLES)
|
if(MLX_BUILD_EXAMPLES)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_BENCHMARKS)
|
if(MLX_BUILD_BENCHMARKS)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------- Installation -----------------------------
|
# ----------------------------- Installation -----------------------------
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
@ -241,27 +232,25 @@ install(
|
|||||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||||
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
INCLUDES
|
||||||
)
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||||
|
|
||||||
|
|
||||||
# Install headers
|
# Install headers
|
||||||
install(
|
install(
|
||||||
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||||
COMPONENT headers
|
COMPONENT headers
|
||||||
FILES_MATCHING PATTERN "*.h"
|
FILES_MATCHING
|
||||||
)
|
PATTERN "*.h")
|
||||||
|
|
||||||
# Install metal dependencies
|
# Install metal dependencies
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
|
|
||||||
# Install metal cpp
|
# Install metal cpp
|
||||||
install(
|
install(
|
||||||
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
||||||
COMPONENT metal_cpp_source
|
COMPONENT metal_cpp_source)
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -273,31 +262,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
|||||||
install(
|
install(
|
||||||
EXPORT MLXTargets
|
EXPORT MLXTargets
|
||||||
FILE MLXTargets.cmake
|
FILE MLXTargets.cmake
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
)
|
|
||||||
|
|
||||||
include(CMakePackageConfigHelpers)
|
include(CMakePackageConfigHelpers)
|
||||||
|
|
||||||
write_basic_package_version_file(
|
write_basic_package_version_file(
|
||||||
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
COMPATIBILITY SameMajorVersion
|
COMPATIBILITY SameMajorVersion
|
||||||
VERSION ${MLX_VERSION}
|
VERSION ${MLX_VERSION})
|
||||||
)
|
|
||||||
|
|
||||||
configure_package_config_file(
|
configure_package_config_file(
|
||||||
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
|
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
|
||||||
${MLX_CMAKE_BUILD_CONFIG}
|
|
||||||
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||||
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
|
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
|
||||||
)
|
MLX_CMAKE_INSTALL_MODULE_DIR)
|
||||||
|
|
||||||
install(
|
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
install(DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||||
DIRECTORY ${CMAKE_MODULE_PATH}/
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
||||||
)
|
|
||||||
|
@ -1,30 +1,21 @@
|
|||||||
include(CMakeParseArguments)
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
###############################################################################
|
# ##############################################################################
|
||||||
# Build metal library
|
# Build metal library
|
||||||
#
|
#
|
||||||
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
||||||
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
||||||
#
|
#
|
||||||
# Args:
|
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||||
# TARGET: Custom target to be added for the metal library
|
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||||
# TITLE: Name of the .metallib
|
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||||
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
# files (like headers)
|
||||||
# SOURCES: List of source files
|
|
||||||
# INCLUDE_DIRS: List of include dirs
|
|
||||||
# DEPS: List of dependency files (like headers)
|
|
||||||
#
|
#
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||||
cmake_parse_arguments(
|
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
MTLLIB
|
|
||||||
""
|
|
||||||
"${oneValueArgs}"
|
|
||||||
"${multiValueArgs}"
|
|
||||||
${ARGN}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set output
|
# Set output
|
||||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||||
@ -35,22 +26,16 @@ macro(mlx_build_metallib)
|
|||||||
# Prepare metallib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||||
COMMAND xcrun -sdk macosx metal
|
COMMAND
|
||||||
|
xcrun -sdk macosx metal
|
||||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||||
${MTLLIB_COMPILE_OPTIONS}
|
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||||
${MTLLIB_SOURCES}
|
|
||||||
-o ${MTLLIB_BUILD_TARGET}
|
|
||||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||||
COMMAND_EXPAND_LISTS
|
COMMAND_EXPAND_LISTS
|
||||||
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
||||||
VERBATIM
|
VERBATIM)
|
||||||
)
|
|
||||||
|
|
||||||
# Add metallib custom target
|
# Add metallib custom target
|
||||||
add_custom_target(
|
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
|
||||||
${MTLLIB_TARGET}
|
|
||||||
DEPENDS
|
|
||||||
${MTLLIB_BUILD_TARGET}
|
|
||||||
)
|
|
||||||
|
|
||||||
endmacro(mlx_build_metallib)
|
endmacro(mlx_build_metallib)
|
@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
|||||||
|
|
||||||
# ----------------------------- Dependencies -----------------------------
|
# ----------------------------- Dependencies -----------------------------
|
||||||
find_package(MLX CONFIG REQUIRED)
|
find_package(MLX CONFIG REQUIRED)
|
||||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
find_package(
|
||||||
|
Python 3.8
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE NB_DIR)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
|
|
||||||
@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
|
|||||||
add_library(mlx_ext)
|
add_library(mlx_ext)
|
||||||
|
|
||||||
# Add sources
|
# Add sources
|
||||||
target_sources(
|
target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
|
||||||
mlx_ext
|
|
||||||
PUBLIC
|
|
||||||
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add include headers
|
# Add include headers
|
||||||
target_include_directories(
|
target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
|
||||||
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Link to mlx
|
# Link to mlx
|
||||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||||
@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
|
|||||||
# Build metallib
|
# Build metallib
|
||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
mlx_build_metallib(
|
mlx_build_metallib(
|
||||||
TARGET mlx_ext_metallib
|
TARGET
|
||||||
TITLE mlx_ext
|
|
||||||
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
|
||||||
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
|
||||||
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
|
||||||
)
|
|
||||||
|
|
||||||
add_dependencies(
|
|
||||||
mlx_ext
|
|
||||||
mlx_ext_metallib
|
mlx_ext_metallib
|
||||||
)
|
TITLE
|
||||||
|
mlx_ext
|
||||||
|
SOURCES
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||||
|
INCLUDE_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}
|
||||||
|
${MLX_INCLUDE_DIRS}
|
||||||
|
OUTPUT_DIRECTORY
|
||||||
|
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||||
|
|
||||||
|
add_dependencies(mlx_ext mlx_ext_metallib)
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------- Python Bindings -----------------------------
|
# ----------------------------- Python Bindings -----------------------------
|
||||||
nanobind_add_module(
|
nanobind_add_module(
|
||||||
_ext
|
_ext
|
||||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
NB_STATIC
|
||||||
NB_DOMAIN mlx
|
STABLE_ABI
|
||||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
LTO
|
||||||
)
|
NOMINSIZE
|
||||||
|
NB_DOMAIN
|
||||||
|
mlx
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
|
||||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||||
|
|
||||||
if(BUILD_SHARED_LIBS)
|
if(BUILD_SHARED_LIBS)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
@ -17,10 +16,9 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||||
)
|
|
||||||
|
|
||||||
if (MLX_BUILD_CPU)
|
if(MLX_BUILD_CPU)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||||
@ -28,17 +26,15 @@ endif()
|
|||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||||
if (MLX_BUILD_ACCELERATE)
|
if(MLX_BUILD_ACCELERATE)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||||
elseif(MLX_BUILD_CPU)
|
elseif(MLX_BUILD_CPU)
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
|
||||||
)
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|
||||||
set(COMPILER ${CMAKE_C_COMPILER})
|
set(COMPILER ${CMAKE_C_COMPILER})
|
||||||
set(CLANG TRUE)
|
set(CLANG TRUE)
|
||||||
else()
|
else()
|
||||||
@ -8,33 +7,25 @@ endif()
|
|||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT compiled_preamble.cpp
|
OUTPUT compiled_preamble.cpp
|
||||||
COMMAND /bin/bash
|
COMMAND
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||||
${COMPILER}
|
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||||
${PROJECT_SOURCE_DIR}
|
|
||||||
${CLANG}
|
|
||||||
|
|
||||||
DEPENDS make_compiled_preamble.sh
|
DEPENDS make_compiled_preamble.sh
|
||||||
compiled_preamble.h
|
compiled_preamble.h
|
||||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||||
ops.h
|
ops.h)
|
||||||
)
|
|
||||||
|
|
||||||
add_custom_target(
|
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||||
cpu_compiled_preamble
|
|
||||||
DEPENDS compiled_preamble.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
add_dependencies(mlx cpu_compiled_preamble)
|
add_dependencies(mlx cpu_compiled_preamble)
|
||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||||
@ -60,19 +51,10 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||||
)
|
|
||||||
|
|
||||||
if (IOS)
|
if(IOS)
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
|
|
||||||
)
|
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
@ -1,99 +1,56 @@
|
|||||||
function(make_jit_source SRC_FILE)
|
function(make_jit_source SRC_FILE)
|
||||||
# This function takes a metal header file,
|
# This function takes a metal header file, runs the C preprocessesor on it,
|
||||||
# runs the C preprocessesor on it, and makes
|
# and makes the processed contents available as a string in a C++ function
|
||||||
# the processed contents available as a string in a C++ function
|
|
||||||
# mlx::core::metal::${SRC_NAME}()
|
# mlx::core::metal::${SRC_NAME}()
|
||||||
#
|
#
|
||||||
# To use the function, declare it in jit/includes.h and
|
# To use the function, declare it in jit/includes.h and include
|
||||||
# include jit/includes.h.
|
# jit/includes.h.
|
||||||
#
|
#
|
||||||
# Additional arguments to this function are treated as dependencies
|
# Additional arguments to this function are treated as dependencies in the
|
||||||
# in the Cmake build system.
|
# Cmake build system.
|
||||||
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
|
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT jit/${SRC_NAME}.cpp
|
OUTPUT jit/${SRC_NAME}.cpp
|
||||||
COMMAND /bin/bash
|
COMMAND
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/jit
|
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
|
||||||
${CMAKE_C_COMPILER}
|
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||||
${PROJECT_SOURCE_DIR}
|
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
|
||||||
${SRC_FILE}
|
|
||||||
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
|
||||||
DEPENDS make_compiled_preamble.sh
|
|
||||||
kernels/${SRC_FILE}.h
|
|
||||||
${ARGN}
|
|
||||||
)
|
|
||||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||||
add_dependencies(mlx ${SRC_NAME})
|
add_dependencies(mlx ${SRC_NAME})
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
|
|
||||||
)
|
|
||||||
endfunction(make_jit_source)
|
endfunction(make_jit_source)
|
||||||
|
|
||||||
make_jit_source(
|
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||||
utils
|
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||||
kernels/bf16.h
|
|
||||||
kernels/complex.h
|
|
||||||
kernels/defines.h
|
|
||||||
)
|
|
||||||
make_jit_source(
|
|
||||||
unary_ops
|
|
||||||
kernels/erf.h
|
|
||||||
kernels/expm1f.h
|
|
||||||
)
|
|
||||||
make_jit_source(binary_ops)
|
make_jit_source(binary_ops)
|
||||||
make_jit_source(ternary_ops)
|
make_jit_source(ternary_ops)
|
||||||
make_jit_source(
|
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||||
reduce_utils
|
|
||||||
kernels/atomic.h
|
|
||||||
kernels/reduction/ops.h
|
|
||||||
)
|
|
||||||
make_jit_source(scatter)
|
make_jit_source(scatter)
|
||||||
make_jit_source(gather)
|
make_jit_source(gather)
|
||||||
make_jit_source(hadamard)
|
make_jit_source(hadamard)
|
||||||
|
|
||||||
if (MLX_METAL_JIT)
|
if(MLX_METAL_JIT)
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
|
|
||||||
)
|
|
||||||
make_jit_source(arange)
|
make_jit_source(arange)
|
||||||
make_jit_source(copy)
|
make_jit_source(copy)
|
||||||
make_jit_source(unary)
|
make_jit_source(unary)
|
||||||
make_jit_source(binary)
|
make_jit_source(binary)
|
||||||
make_jit_source(binary_two)
|
make_jit_source(binary_two)
|
||||||
make_jit_source(
|
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||||
fft
|
|
||||||
kernels/fft/radix.h
|
|
||||||
kernels/fft/readwrite.h
|
|
||||||
)
|
|
||||||
make_jit_source(ternary)
|
make_jit_source(ternary)
|
||||||
make_jit_source(softmax)
|
make_jit_source(softmax)
|
||||||
make_jit_source(scan)
|
make_jit_source(scan)
|
||||||
make_jit_source(sort)
|
make_jit_source(sort)
|
||||||
make_jit_source(
|
make_jit_source(
|
||||||
reduce
|
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
|
||||||
kernels/reduction/reduce_all.h
|
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
|
||||||
kernels/reduction/reduce_col.h
|
|
||||||
kernels/reduction/reduce_row.h
|
|
||||||
kernels/reduction/reduce_init.h
|
|
||||||
)
|
|
||||||
make_jit_source(
|
make_jit_source(
|
||||||
steel/gemm/gemm
|
steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
|
||||||
kernels/steel/utils.h
|
kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
|
||||||
kernels/steel/gemm/loader.h
|
kernels/steel/gemm/transforms.h)
|
||||||
kernels/steel/gemm/mma.h
|
|
||||||
kernels/steel/gemm/params.h
|
|
||||||
kernels/steel/gemm/transforms.h
|
|
||||||
)
|
|
||||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||||
make_jit_source(
|
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||||
steel/gemm/kernels/steel_gemm_masked
|
|
||||||
kernels/steel/defines.h
|
|
||||||
)
|
|
||||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||||
make_jit_source(
|
make_jit_source(
|
||||||
steel/conv/conv
|
steel/conv/conv
|
||||||
@ -104,30 +61,19 @@ if (MLX_METAL_JIT)
|
|||||||
kernels/steel/conv/params.h
|
kernels/steel/conv/params.h
|
||||||
kernels/steel/conv/loader.h
|
kernels/steel/conv/loader.h
|
||||||
kernels/steel/conv/loaders/loader_channel_l.h
|
kernels/steel/conv/loaders/loader_channel_l.h
|
||||||
kernels/steel/conv/loaders/loader_channel_n.h
|
kernels/steel/conv/loaders/loader_channel_n.h)
|
||||||
)
|
make_jit_source(steel/conv/kernels/steel_conv)
|
||||||
make_jit_source(
|
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
|
||||||
steel/conv/kernels/steel_conv
|
kernels/steel/conv/loaders/loader_general.h)
|
||||||
)
|
|
||||||
make_jit_source(
|
|
||||||
steel/conv/kernels/steel_conv_general
|
|
||||||
kernels/steel/defines.h
|
|
||||||
kernels/steel/conv/loaders/loader_general.h
|
|
||||||
)
|
|
||||||
make_jit_source(quantized)
|
make_jit_source(quantized)
|
||||||
make_jit_source(gemv_masked)
|
make_jit_source(gemv_masked)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
@ -153,14 +99,13 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
)
|
|
||||||
|
|
||||||
if (NOT MLX_METAL_PATH)
|
if(NOT MLX_METAL_PATH)
|
||||||
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
||||||
|
|
||||||
target_compile_definitions(
|
target_compile_definitions(mlx
|
||||||
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|
PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|
||||||
|
@ -1,38 +1,26 @@
|
|||||||
set(
|
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
||||||
BASE_HEADERS
|
|
||||||
bf16.h
|
|
||||||
bf16_math.h
|
|
||||||
complex.h
|
|
||||||
defines.h
|
|
||||||
expm1f.h
|
|
||||||
utils.h
|
|
||||||
)
|
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS}
|
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||||
-gline-tables-only
|
|
||||||
-frecord-sources)
|
|
||||||
endif()
|
endif()
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
COMMAND xcrun -sdk macosx metal
|
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||||
${METAL_FLAGS}
|
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||||
-c ${SRCFILE}
|
|
||||||
-I${PROJECT_SOURCE_DIR}
|
|
||||||
-o ${TARGET}.air
|
|
||||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||||
OUTPUT ${TARGET}.air
|
OUTPUT ${TARGET}.air
|
||||||
COMMENT "Building ${TARGET}.air"
|
COMMENT "Building ${TARGET}.air"
|
||||||
VERBATIM
|
VERBATIM)
|
||||||
)
|
|
||||||
endfunction(build_kernel_base)
|
endfunction(build_kernel_base)
|
||||||
|
|
||||||
function(build_kernel KERNEL)
|
function(build_kernel KERNEL)
|
||||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||||
cmake_path(GET KERNEL STEM TARGET)
|
cmake_path(GET KERNEL STEM TARGET)
|
||||||
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
||||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
|
set(KERNEL_AIR
|
||||||
|
${TARGET}.air ${KERNEL_AIR}
|
||||||
|
PARENT_SCOPE)
|
||||||
endfunction(build_kernel)
|
endfunction(build_kernel)
|
||||||
|
|
||||||
build_kernel(arg_reduce)
|
build_kernel(arg_reduce)
|
||||||
@ -42,16 +30,10 @@ build_kernel(layer_norm)
|
|||||||
build_kernel(random)
|
build_kernel(random)
|
||||||
build_kernel(rms_norm)
|
build_kernel(rms_norm)
|
||||||
build_kernel(rope)
|
build_kernel(rope)
|
||||||
build_kernel(
|
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
|
||||||
scaled_dot_product_attention
|
steel/defines.h steel/gemm/transforms.h steel/utils.h)
|
||||||
scaled_dot_product_attention_params.h
|
|
||||||
steel/defines.h
|
|
||||||
steel/gemm/transforms.h
|
|
||||||
steel/utils.h
|
|
||||||
)
|
|
||||||
|
|
||||||
set(
|
set(STEEL_HEADERS
|
||||||
STEEL_HEADERS
|
|
||||||
steel/defines.h
|
steel/defines.h
|
||||||
steel/utils.h
|
steel/utils.h
|
||||||
steel/conv/conv.h
|
steel/conv/conv.h
|
||||||
@ -67,81 +49,47 @@ set(
|
|||||||
steel/gemm/transforms.h
|
steel/gemm/transforms.h
|
||||||
steel/gemm/kernels/steel_gemm_fused.h
|
steel/gemm/kernels/steel_gemm_fused.h
|
||||||
steel/gemm/kernels/steel_gemm_masked.h
|
steel/gemm/kernels/steel_gemm_masked.h
|
||||||
steel/gemm/kernels/steel_gemm_splitk.h
|
steel/gemm/kernels/steel_gemm_splitk.h)
|
||||||
)
|
|
||||||
|
|
||||||
if (NOT MLX_METAL_JIT)
|
if(NOT MLX_METAL_JIT)
|
||||||
build_kernel(arange arange.h)
|
build_kernel(arange arange.h)
|
||||||
build_kernel(binary binary.h binary_ops.h)
|
build_kernel(binary binary.h binary_ops.h)
|
||||||
build_kernel(binary_two binary_two.h)
|
build_kernel(binary_two binary_two.h)
|
||||||
build_kernel(copy copy.h)
|
build_kernel(copy copy.h)
|
||||||
build_kernel(
|
build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
|
||||||
fft
|
build_kernel(
|
||||||
fft.h
|
|
||||||
fft/radix.h
|
|
||||||
fft/readwrite.h
|
|
||||||
)
|
|
||||||
build_kernel(
|
|
||||||
reduce
|
reduce
|
||||||
atomic.h
|
atomic.h
|
||||||
reduction/ops.h
|
reduction/ops.h
|
||||||
reduction/reduce_init.h
|
reduction/reduce_init.h
|
||||||
reduction/reduce_all.h
|
reduction/reduce_all.h
|
||||||
reduction/reduce_col.h
|
reduction/reduce_col.h
|
||||||
reduction/reduce_row.h
|
reduction/reduce_row.h)
|
||||||
)
|
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||||
build_kernel(
|
build_kernel(scan scan.h)
|
||||||
quantized
|
build_kernel(softmax softmax.h)
|
||||||
quantized.h
|
build_kernel(sort sort.h)
|
||||||
${STEEL_HEADERS}
|
build_kernel(ternary ternary.h ternary_ops.h)
|
||||||
)
|
build_kernel(unary unary.h unary_ops.h)
|
||||||
build_kernel(scan scan.h)
|
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||||
build_kernel(softmax softmax.h)
|
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||||
build_kernel(sort sort.h)
|
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||||
build_kernel(ternary ternary.h ternary_ops.h)
|
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||||
build_kernel(unary unary.h unary_ops.h)
|
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||||
build_kernel(
|
build_kernel(gemv_masked steel/utils.h)
|
||||||
steel/conv/kernels/steel_conv
|
|
||||||
${STEEL_HEADERS}
|
|
||||||
)
|
|
||||||
build_kernel(
|
|
||||||
steel/conv/kernels/steel_conv_general
|
|
||||||
${STEEL_HEADERS}
|
|
||||||
)
|
|
||||||
build_kernel(
|
|
||||||
steel/gemm/kernels/steel_gemm_fused
|
|
||||||
${STEEL_HEADERS}
|
|
||||||
)
|
|
||||||
build_kernel(
|
|
||||||
steel/gemm/kernels/steel_gemm_masked
|
|
||||||
${STEEL_HEADERS}
|
|
||||||
)
|
|
||||||
build_kernel(
|
|
||||||
steel/gemm/kernels/steel_gemm_splitk
|
|
||||||
${STEEL_HEADERS}
|
|
||||||
)
|
|
||||||
build_kernel(gemv_masked steel/utils.h)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||||
|
${MLX_METAL_PATH}/mlx.metallib
|
||||||
DEPENDS ${KERNEL_AIR}
|
DEPENDS ${KERNEL_AIR}
|
||||||
COMMENT "Building mlx.metallib"
|
COMMENT "Building mlx.metallib"
|
||||||
VERBATIM
|
VERBATIM)
|
||||||
)
|
|
||||||
|
|
||||||
add_custom_target(
|
add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)
|
||||||
mlx-metallib
|
|
||||||
DEPENDS
|
|
||||||
${MLX_METAL_PATH}/mlx.metallib
|
|
||||||
)
|
|
||||||
|
|
||||||
add_dependencies(
|
add_dependencies(mlx mlx-metallib)
|
||||||
mlx
|
|
||||||
mlx-metallib
|
|
||||||
)
|
|
||||||
|
|
||||||
# Install metallib
|
# Install metallib
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
@ -149,5 +97,4 @@ include(GNUInstallDirs)
|
|||||||
install(
|
install(
|
||||||
FILES ${MLX_METAL_PATH}/mlx.metallib
|
FILES ${MLX_METAL_PATH}/mlx.metallib
|
||||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
COMPONENT metallib
|
COMPONENT metallib)
|
||||||
)
|
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
|
||||||
)
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
|
||||||
)
|
|
||||||
|
@ -1,16 +1,8 @@
|
|||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
mlx
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
if (MPI_FOUND AND MLX_BUILD_CPU)
|
if(MPI_FOUND AND MLX_BUILD_CPU)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
@ -1,5 +1 @@
|
|||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
|
|
||||||
)
|
|
||||||
|
@ -1,58 +1,32 @@
|
|||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)
|
||||||
|
|
||||||
target_sources(
|
if(MLX_BUILD_SAFETENSORS)
|
||||||
mlx
|
message(STATUS "Downloading json")
|
||||||
PRIVATE
|
FetchContent_Declare(
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
json
|
||||||
)
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||||
|
|
||||||
if (MLX_BUILD_SAFETENSORS)
|
|
||||||
MESSAGE(STATUS "Downloading json")
|
|
||||||
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
|
||||||
FetchContent_MakeAvailable(json)
|
FetchContent_MakeAvailable(json)
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx PRIVATE
|
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||||
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
|
||||||
)
|
|
||||||
target_sources(
|
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp
|
|
||||||
)
|
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_GGUF)
|
if(MLX_BUILD_GGUF)
|
||||||
MESSAGE(STATUS "Downloading gguflib")
|
message(STATUS "Downloading gguflib")
|
||||||
FetchContent_Declare(gguflib
|
FetchContent_Declare(
|
||||||
|
gguflib
|
||||||
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
|
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
|
||||||
GIT_TAG af7d88d808a7608a33723fba067036202910acb3
|
GIT_TAG af7d88d808a7608a33723fba067036202910acb3)
|
||||||
)
|
|
||||||
FetchContent_MakeAvailable(gguflib)
|
FetchContent_MakeAvailable(gguflib)
|
||||||
target_include_directories(
|
target_include_directories(mlx
|
||||||
mlx PRIVATE
|
PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)
|
||||||
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
|
add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c
|
||||||
)
|
|
||||||
add_library(
|
|
||||||
gguflib STATIC
|
|
||||||
${gguflib_SOURCE_DIR}/fp16.c
|
|
||||||
${gguflib_SOURCE_DIR}/gguflib.c)
|
${gguflib_SOURCE_DIR}/gguflib.c)
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
|
||||||
mlx
|
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
|
|
||||||
)
|
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)
|
||||||
mlx
|
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
nanobind_add_module(
|
nanobind_add_module(
|
||||||
core
|
core
|
||||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
NB_STATIC
|
||||||
NB_DOMAIN mlx
|
STABLE_ABI
|
||||||
|
LTO
|
||||||
|
NOMINSIZE
|
||||||
|
NB_DOMAIN
|
||||||
|
mlx
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||||
@ -19,19 +23,14 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
)
|
|
||||||
|
|
||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set_target_properties(
|
set_target_properties(core PROPERTIES LIBRARY_OUTPUT_DIRECTORY
|
||||||
core
|
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
|
||||||
PROPERTIES
|
|
||||||
LIBRARY_OUTPUT_DIRECTORY
|
|
||||||
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
|
|
||||||
)
|
|
||||||
|
|
||||||
target_link_libraries(core PRIVATE mlx)
|
target_link_libraries(core PRIVATE mlx)
|
||||||
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
|
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
|
||||||
|
@ -1,23 +1,20 @@
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
doctest
|
doctest
|
||||||
GIT_REPOSITORY "https://github.com/onqtam/doctest"
|
GIT_REPOSITORY "https://github.com/onqtam/doctest"
|
||||||
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2"
|
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
|
||||||
)
|
|
||||||
FetchContent_MakeAvailable(doctest)
|
FetchContent_MakeAvailable(doctest)
|
||||||
|
|
||||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||||
|
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
set(
|
set(METAL_TEST_SOURCES metal_tests.cpp)
|
||||||
METAL_TEST_SOURCES
|
|
||||||
metal_tests.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
|
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
|
||||||
|
|
||||||
target_sources(tests PRIVATE
|
target_sources(
|
||||||
allocator_tests.cpp
|
tests
|
||||||
|
PRIVATE allocator_tests.cpp
|
||||||
array_tests.cpp
|
array_tests.cpp
|
||||||
arg_reduce_tests.cpp
|
arg_reduce_tests.cpp
|
||||||
autograd_tests.cpp
|
autograd_tests.cpp
|
||||||
@ -36,8 +33,7 @@ target_sources(tests PRIVATE
|
|||||||
utils_tests.cpp
|
utils_tests.cpp
|
||||||
vmap_tests.cpp
|
vmap_tests.cpp
|
||||||
linalg_tests.cpp
|
linalg_tests.cpp
|
||||||
${METAL_TEST_SOURCES}
|
${METAL_TEST_SOURCES})
|
||||||
)
|
|
||||||
|
|
||||||
target_link_libraries(tests PRIVATE mlx doctest)
|
target_link_libraries(tests PRIVATE mlx doctest)
|
||||||
doctest_discover_tests(tests)
|
doctest_discover_tests(tests)
|
||||||
|
Loading…
Reference in New Issue
Block a user