[CUDA] Implement Scan kernel (#2347)

* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
This commit is contained in:
Cheng
2025-07-11 08:54:12 +09:00
committed by GitHub
parent b6eec20260
commit 8347575ba1
13 changed files with 815 additions and 64 deletions

View File

@@ -35,6 +35,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed.