This commit is contained in:
Awni Hannun
2025-10-20 09:51:30 -07:00
parent ec4c85d786
commit 73613c4d15
6 changed files with 45 additions and 6 deletions

View File

@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()

View File

@@ -52,6 +52,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
@@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -334,4 +334,17 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu

View File

@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
auto& in = inputs[0];
auto& out = outputs[0];
auto& s = out.primitive().stream();
if (to_fp8_) {
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
} else {
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, ToFP8>) {
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
}
if (std::is_same_v<Op, FromFP8>) {
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
}
return false;
}

View File

@@ -2,7 +2,6 @@
// Required for using M_PI_2 in MSVC.
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
@@ -4038,6 +4037,7 @@ TEST_CASE("test fp8 conversion") {
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
auto in_fp8 = to_fp8(in);
auto out = from_fp8(in_fp8, t);
CHECK(array_equal(out, in).item<bool>());
}
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});