From 8be324c26e49503b464b69f81931c86f9c1d139a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 23 Oct 2025 09:43:36 -0700
Subject: [PATCH] fallback for cuda < 12.8 (#2697)

---
 mlx/backend/cuda/CMakeLists.txt       |  5 ++
 mlx/backend/cuda/quantized/cuda_fp4.h | 71 +++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 mlx/backend/cuda/quantized/cuda_fp4.h

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 19cafb932..7f8f1aade 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -59,6 +59,11 @@ target_sources(
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
 
+# fp4 is not available on < 12.8
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
+  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
+endif()
+
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
   target_sources(
     mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
diff --git a/mlx/backend/cuda/quantized/cuda_fp4.h b/mlx/backend/cuda/quantized/cuda_fp4.h
new file mode 100644
index 000000000..abda0df70
--- /dev/null
+++ b/mlx/backend/cuda/quantized/cuda_fp4.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+
+// Fallback definition of __nv_fp4_e2m1 for CUDA toolkits older than 12.8,
+// which do not provide fp4. A value is a 4-bit code in the low nibble of
+// __x: bit 3 is the sign and bits 0-2 select one of the magnitudes
+// {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
+struct __nv_fp4_e2m1 {
+  // Encode a float as the nearest representable value. Exact midpoints go to
+  // the neighbour with the even code, magnitudes above 5 saturate to 6, and
+  // NaN encodes as +6 (code 0x7).
+  __device__ __nv_fp4_e2m1(float x) {
+    if (std::isnan(x)) {
+      __x = 0x7;
+      return;
+    }
+
+    const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
+    x = std::abs(x);
+
+    // Each threshold is the midpoint of two adjacent magnitudes; the choice
+    // of ">" or ">=" decides which neighbour receives an exact tie.
+    if (x > 5.0f) {
+      __x = 0x7;
+    } else if (x >= 3.5f) {
+      __x = 0x6;
+    } else if (x > 2.5f) {
+      __x = 0x5;
+    } else if (x >= 1.75f) {
+      __x = 0x4;
+    } else if (x > 1.25f) {
+      __x = 0x3;
+    } else if (x >= 0.75f) {
+      __x = 0x2;
+    } else if (x > 0.25f) {
+      __x = 0x1;
+    } else {
+      __x = 0x0;
+    }
+    __x |= sign_bit;
+  }
+
+  // Decode the stored code back to float. Declared const so the conversion
+  // also works on const objects and const references.
+  __device__ operator float() const {
+    // Indexed by the 4-bit code; entries 8-15 are the negated entries 0-7.
+    static const float LUT[16] = {
+        0.0f,
+        0.5f,
+        1.0f,
+        1.5f,
+        2.0f,
+        3.0f,
+        4.0f,
+        6.0f,
+        -0.0f,
+        -0.5f,
+        -1.0f,
+        -1.5f,
+        -2.0f,
+        -3.0f,
+        -4.0f,
+        -6.0f
+    };
+
+    return LUT[__x];
+  }
+  uint8_t __x{0};
+};