CUDA backend: matmul (#2241)

This commit is contained in:
Cheng
2025-06-07 04:24:04 +09:00
committed by GitHub
parent c6a20b427a
commit 24f89173d1
7 changed files with 584 additions and 67 deletions

View File

@@ -12,6 +12,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
@@ -53,6 +54,9 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)