mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add cuda
This commit is contained in:
@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
@@ -52,6 +52,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||
@@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
||||
@@ -334,4 +334,17 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T>
|
||||
__device__ uint8_t operator()(T x) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
return float(*(__nv_fp8_e4m3*)(&x));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
19
mlx/backend/cuda/quantized/convert_fp8.cu
Normal file
19
mlx/backend/cuda/quantized/convert_fp8.cu
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
void fast::ConvertFP8::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (to_fp8_) {
|
||||
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
|
||||
} else {
|
||||
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, ToFP8>) {
|
||||
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, FromFP8>) {
|
||||
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
// Required for using M_PI_2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -4038,6 +4037,7 @@ TEST_CASE("test fp8 conversion") {
|
||||
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
|
||||
auto in_fp8 = to_fp8(in);
|
||||
auto out = from_fp8(in_fp8, t);
|
||||
CHECK(array_equal(out, in).item<bool>());
|
||||
}
|
||||
|
||||
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
|
||||
|
||||
Reference in New Issue
Block a user