diff --git a/CMakeLists.txt b/CMakeLists.txt index 84d4198ba..0dbc0b51b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW) add_library(mlx) +# Suppress warnings: note: parameter passing for argument of type +# ‘std::pair’ when C++17 is enabled changed to match C++14 in GCC +# 10.1 +target_compile_options(mlx PRIVATE -Wno-psabi) + if(MLX_BUILD_CUDA) enable_language(CUDA) endif() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 5c4bf7115..eabee94f2 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -52,6 +52,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) @@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all) # Suppress nvcc warnings on MLX headers. target_compile_options(mlx PRIVATE $<$:-Xcudafe --diag_suppress=997>) -# Supress warnings: note: parameter passing for argument of type -# ‘std::pair’ when C++17 is enabled changed to match C++14 in GCC -# 10.1 -target_compile_options(mlx PRIVATE -Wno-psabi) - # Install CCCL headers for JIT. 
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index fcd083f2f..e15d1cf52 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -334,4 +334,17 @@ struct Tanh { } }; +struct ToFP8 { + template + __device__ uint8_t operator()(T x) { + return __nv_fp8_e4m3(x).__x; + } +}; + +struct FromFP8 { + __device__ float operator()(uint8_t x) { + return float(*(__nv_fp8_e4m3*)(&x)); + } +}; + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/quantized/convert_fp8.cu b/mlx/backend/cuda/quantized/convert_fp8.cu new file mode 100644 index 000000000..c0be2e381 --- /dev/null +++ b/mlx/backend/cuda/quantized/convert_fp8.cu @@ -0,0 +1,19 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/unary/unary.cuh" +#include "mlx/fast_primitives.h" + +namespace mlx::core { +void fast::ConvertFP8::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("ConvertFP8::eval_gpu"); + auto& in = inputs[0]; + auto& out = outputs[0]; + auto& s = out.primitive().stream(); + if (to_fp8_) { + unary_op_gpu(inputs, out, name(), s); + } else { + unary_op_gpu(inputs, out, name(), s); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh index a20e119ca..8f4a02d50 100644 --- a/mlx/backend/cuda/unary/unary.cuh +++ b/mlx/backend/cuda/unary/unary.cuh @@ -108,6 +108,12 @@ constexpr bool supports_unary_op() { if (std::is_same_v) { return std::is_same_v && std::is_same_v; } + if (std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v) { + return std::is_same_v && is_floating_v; + } return false; } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 34ac31a03..2b787e63a 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2,7 +2,6 @@ // Required for using M_PI_2 in MSVC. 
#define _USE_MATH_DEFINES - #include #include @@ -4038,6 +4037,7 @@ TEST_CASE("test fp8 conversion") { array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t); auto in_fp8 = to_fp8(in); auto out = from_fp8(in_fp8, t); + CHECK(array_equal(out, in).item()); } array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});