mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fallback for cuda < 12.8 (#2697)
This commit is contained in:
@@ -59,6 +59,11 @@ target_sources(
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
||||||
|
|
||||||
|
# fp4 is not available on < 12.8
|
||||||
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
|
||||||
|
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||||
|
|||||||
56
mlx/backend/cuda/quantized/cuda_fp4.h
Normal file
56
mlx/backend/cuda/quantized/cuda_fp4.h
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
struct __nv_fp4_e2m1 {
|
||||||
|
__device__ __nv_fp4_e2m1(float x) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
__x = 0x7;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
|
||||||
|
x = std::abs(x);
|
||||||
|
|
||||||
|
if (x > 5.0f) {
|
||||||
|
__x = 0x7;
|
||||||
|
} else if (x >= 3.5f) {
|
||||||
|
__x = 0x6;
|
||||||
|
} else if (x > 2.5f) {
|
||||||
|
__x = 0x5;
|
||||||
|
} else if (x >= 1.75f) {
|
||||||
|
__x = 0x4;
|
||||||
|
} else if (x > 1.25f) {
|
||||||
|
__x = 0x3;
|
||||||
|
} else if (x >= 0.75f) {
|
||||||
|
__x = 0x2;
|
||||||
|
} else if (x > 0.25f) {
|
||||||
|
__x = 0x1;
|
||||||
|
} else {
|
||||||
|
__x = 0x0;
|
||||||
|
}
|
||||||
|
__x |= sign_bit;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ operator float() {
|
||||||
|
static const float LUT[16] = {
|
||||||
|
0.0f,
|
||||||
|
0.5f,
|
||||||
|
1.0f,
|
||||||
|
1.5f,
|
||||||
|
2.0f,
|
||||||
|
3.0f,
|
||||||
|
4.0f,
|
||||||
|
6.0f,
|
||||||
|
-0.0f,
|
||||||
|
-0.5f,
|
||||||
|
-1.0f,
|
||||||
|
-1.5f,
|
||||||
|
-2.0f,
|
||||||
|
-3.0f,
|
||||||
|
-4.0f,
|
||||||
|
-6.0f
|
||||||
|
};
|
||||||
|
|
||||||
|
return LUT[__x];
|
||||||
|
}
|
||||||
|
uint8_t __x{0};
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user