From 012fb220a19035c2ca28ddcf8b2d204eaaf7e761 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 11 Dec 2025 15:11:25 +0100 Subject: [PATCH] fp quantize (#2892) --- mlx/backend/cuda/quantized/fp_quantize.cu | 88 ++--- mlx/backend/cuda/quantized/mxfp8_quantize.cuh | 32 ++ mlx/backend/cuda/quantized/nvfp4_quantize.cuh | 334 ++++++++++++++++++ .../cuda/quantized/quantized_utils.cuh | 16 + mlx/backend/cuda/steel/tiles.cuh | 23 +- mlx/backend/cuda/vector_types.cuh | 48 +++ 6 files changed, 479 insertions(+), 62 deletions(-) create mode 100644 mlx/backend/cuda/quantized/mxfp8_quantize.cuh create mode 100644 mlx/backend/cuda/quantized/nvfp4_quantize.cuh create mode 100644 mlx/backend/cuda/vector_types.cuh diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 45c61baf5..44e01d0c4 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -2,7 +2,11 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" +#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" #include "mlx/backend/cuda/quantized/quantized.h" +#include "mlx/backend/cuda/quantized/quantized_utils.cuh" +#include "mlx/backend/cuda/vector_types.cuh" #include "mlx/dtype_utils.h" #include @@ -13,17 +17,6 @@ namespace mlx::core { namespace cu { -template -struct Quantize { - __device__ uint8_t operator()(float x) { - if constexpr (bits == 8) { - return __nv_fp8_e4m3(x).__x; - } else { - return __nv_fp4_e2m1(x).__x; - } - } -}; - template struct Dequantize { __device__ float operator()(uint8_t x) { @@ -37,29 +30,40 @@ struct Dequantize { namespace cg = cooperative_groups; -template -__global__ void -fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { +template +__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); - auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - auto grid_dim_x = - cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; - size_t index = tidx + grid_dim_x * size_t(tidy); - if (index >= size) { + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { return; } - float w_thread = w[index]; + auto w_tile = load_vector(w, thread_idx); + float scale = 0.0f; - cg::greater max_op; - auto warp = cg::tiled_partition(cg::this_thread_block()); + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + abs_max_x2(amax_2x, amax_2x, pair); + } + + scale = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); - float scale = cg::reduce(warp, abs(w_thread), max_op); scale /= bits == 4 ? 
6.0f : 448.0f; // Convert to mx scale or nv scale using ScaleType = @@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { uint8_t q_scale = s.__x; scale = float(s); - // Write out the scales - size_t gindex = index / group_size; - if (index % group_size == 0) { - scales[gindex] = q_scale; - } + scales[thread_idx] = q_scale; + constexpr int elem_per_byte = bits == 8 ? 1 : 2; + AlignedVector quantized; - uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); - if (bits == 4) { - uint8_t sval = warp.shfl_down(output, 1); - output |= sval << bits; - } - constexpr int pack_factor = bits == 8 ? 1 : 2; - if (index % pack_factor == 0) { - out[index / pack_factor] = output; +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast(&quantized[i * 4]) = quantized_val; + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast(&quantized[i * 2]) = quantized_val; + } } + store_vector(out, thread_idx, quantized); } template @@ -142,15 +149,16 @@ void fp_quantize( dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { using T = cuda_type_t; if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize; + auto kernel = cu::fp_quantize; if (bits == 8) { - kernel = cu::fp_quantize; + kernel = cu::fp_quantize; } else if (group_size == 16) { - kernel = cu::fp_quantize; + kernel = cu::fp_quantize; } bool large = w.size() > UINT_MAX; auto [num_blocks, block_dims] = - get_launch_args(w.size(), w.shape(), w.strides(), large); + get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); + enc.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/quantized/mxfp8_quantize.cuh b/mlx/backend/cuda/quantized/mxfp8_quantize.cuh new file mode 100644 index 000000000..1edf36417 --- /dev/null +++ b/mlx/backend/cuda/quantized/mxfp8_quantize.cuh @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include +#include "mlx/backend/cuda/vector_types.cuh" + +namespace mlx::core::cu { + +// TODO implement fast path +template +__device__ __forceinline__ uint32_t +scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t input, const float scale) { + uint32_t out_fp8x4 = 0; + float4 scaled; + scaled.x = static_cast(input.x) * scale; + scaled.y = static_cast(input.y) * scale; + scaled.z = static_cast(input.z) * scale; + scaled.w = static_cast(input.w) * scale; + out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x; + return out_fp8x4; +} + +// Place holder for future fast path implementation +template +__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4( + const Vector4_t input, + const float scale, + uint32_t rbits) { + return scale_cvt_Tx4_to_fp8x4_fallback(input, scale); +} +} // namespace mlx::core::cu \ No newline at end of file diff --git a/mlx/backend/cuda/quantized/nvfp4_quantize.cuh b/mlx/backend/cuda/quantized/nvfp4_quantize.cuh new file mode 100644 index 000000000..da0df4293 --- /dev/null +++ b/mlx/backend/cuda/quantized/nvfp4_quantize.cuh @@ -0,0 +1,334 @@ +#pragma once + +#include +#include +#include +#include "mlx/backend/cuda/vector_types.cuh" + +namespace mlx::core::cu { + +using bf16x4 = Vector4_t<__nv_bfloat16>; +using fp16x4 = Vector4_t<__half>; +using f32x4 = Vector4_t; + +template +__device__ __forceinline__ uint16_t +scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t input, const float scale) { + // Fallback 
implementation for architectures that do not support cvt + // instructions or for cuda versions with no fp4 support (< 12.8) -> scalar + uint16_t out_fp4x4 = 0; + fp32x4 scaled; + scaled.x = static_cast(input.x) * scale; + scaled.y = static_cast(input.y) * scale; + scaled.z = static_cast(input.z) * scale; + scaled.w = static_cast(input.w) * scale; + uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x; + uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x; + uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x; + uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x; + out_fp4x4 = (static_cast(q3) << 12) | + (static_cast(q2) << 8) | (static_cast(q1) << 4) | + static_cast(q0); + return out_fp4x4; +} + +#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ + defined(__CUDA_ARCH_SPECIFIC__) + +__device__ __forceinline__ uint16_t +scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) { + uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b16 x0_bf16; \n\t" // first bf16 + ".reg.b16 x1_bf16; \n\t" // second bf16 + ".reg.b16 x2_bf16; \n\t" // third bf16 + ".reg.b16 x3_bf16; \n\t" // fourth bf16 + ".reg.b32 x0; \n\t" // to hold scaled first + ".reg.b32 x1; \n\t" // to hold scaled second + ".reg.b32 x2; \n\t" // to hold scaled third + ".reg.b32 x3; \n\t" // to hold scaled fourth + ".reg.b64 x01; \n\t" // to hold vector mul + ".reg.b64 x23; \n\t" + ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair) + ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair) + "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16 + "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32 + "cvt.f32.bf16 x1, x1_bf16; \n\t" + "cvt.f32.bf16 x2, x2_bf16; \n\t" + "cvt.f32.bf16 x3, x3_bf16; \n\t" + "mov.b64 x01, {x0, x1}; \n\t" + "mul.f32x2 x01, x01, %2; \n\t" // scale first pair + "mov.b64 x23, {x2, x3}; \n\t" + "mul.f32x2 x23, x23, %2; \n\t" // scale second pair + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first + // pair + "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second + // pair + "mov.b16 %0, {q0, q1}; \n\t" // pack to output + "}" + : "=h"(out_fp4x4) + : "l"(reinterpret_cast(input_bf16x4)), + "l"(reinterpret_cast( + scale))); // here cast is needed becuase an asm operand must have + // scalar type + return out_fp4x4; +} + +__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs( + const bf16x4 input_bf16x4, + const float2 scale, + uint32_t rbits) { + uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b16 x0_bf16; \n\t" + ".reg.b16 x1_bf16; \n\t" + ".reg.b16 x2_bf16; \n\t" + ".reg.b16 x3_bf16; \n\t" + ".reg.b32 x0; \n\t" + ".reg.b32 x1; \n\t" + ".reg.b32 x2; \n\t" + ".reg.b32 x3; \n\t" + ".reg.b64 x01; \n\t" + ".reg.b64 x23; \n\t" + ".reg.b16 q0; \n\t" + "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" + "cvt.f32.bf16 x0, x0_bf16; \n\t" + "cvt.f32.bf16 x1, x1_bf16; \n\t" + "cvt.f32.bf16 x2, x2_bf16; \n\t" + "cvt.f32.bf16 x3, x3_bf16; \n\t" + "mov.b64 x01, {x0, x1}; \n\t" + "mul.f32x2 x01, x01, %2; \n\t" + "mov.b64 x23, {x2, x3}; \n\t" + "mul.f32x2 x23, x23, %2; \n\t" + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t" + "}" + : "=h"(out_fp4x4) + : "l"(reinterpret_cast(input_bf16x4)), + "l"(reinterpret_cast(scale)), + "r"(rbits)); + return out_fp4x4; +} + +__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn( + const float2 input_fp32x2_0, + const float2 input_fp32x2_1, + const float2 scale) { + 
uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b32 x0; \n\t" + ".reg.b32 x1; \n\t" + ".reg.b32 x2; \n\t" + ".reg.b32 x3; \n\t" + ".reg.b64 x01; \n\t" + ".reg.b64 x23; \n\t" + ".reg.b8 q0; \n\t" + ".reg.b8 q1; \n\t" + "mov.b64 x01, {%1, %2}; \n\t" + "mul.f32x2 x01, x01, %5; \n\t" + "mov.b64 x23, {%3, %4}; \n\t" + "mul.f32x2 x23, x23, %5; \n\t" + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" + "mov.b16 %0, {q0, q1}; \n\t" + "}" + : "=h"(out_fp4x4) + : "f"(input_fp32x2_0.x), + "f"(input_fp32x2_0.y), + "f"(input_fp32x2_1.x), + "f"(input_fp32x2_1.y), + "l"(reinterpret_cast(scale))); + return out_fp4x4; +} + +__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs( + const float2 input_fp32x2_0, + const float2 input_fp32x2_1, + const float2 scale, + uint32_t rbits) { + uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b32 x0; \n\t" + ".reg.b32 x1; \n\t" + ".reg.b32 x2; \n\t" + ".reg.b32 x3; \n\t" + ".reg.b64 x01; \n\t" + ".reg.b64 x23; \n\t" + ".reg.b16 q0; \n\t" + "mov.b64 x01, {%1, %2}; \n\t" + "mul.f32x2 x01, x01, %5; \n\t" + "mov.b64 x23, {%3, %4}; \n\t" + "mul.f32x2 x23, x23, %5; \n\t" + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t" + "}" + : "=h"(out_fp4x4) + : "f"(input_fp32x2_0.x), + "f"(input_fp32x2_0.y), + "f"(input_fp32x2_1.x), + "f"(input_fp32x2_1.y), + "l"(reinterpret_cast(scale)), + "r"(rbits)); + return out_fp4x4; +} + +__device__ __forceinline__ uint16_t +scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) { + uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b16 x0_fp16; \n\t" + ".reg.b16 x1_fp16; \n\t" + ".reg.b16 x2_fp16; \n\t" + ".reg.b16 x3_fp16; \n\t" + ".reg.b32 x0; \n\t" + ".reg.b32 x1; \n\t" + ".reg.b32 x2; \n\t" + ".reg.b32 x3; \n\t" + ".reg.b64 x01; \n\t" + ".reg.b64 x23; \n\t" + ".reg.b8 q0; \n\t" + ".reg.b8 q1; \n\t" + "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t" + "cvt.f32.f16 x0, x0_fp16; \n\t" + "cvt.f32.f16 x1, x1_fp16; \n\t" + "cvt.f32.f16 x2, x2_fp16; \n\t" + "cvt.f32.f16 x3, x3_fp16; \n\t" + "mov.b64 x01, {x0, x1}; \n\t" + "mul.f32x2 x01, x01, %2; \n\t" + "mov.b64 x23, {x2, x3}; \n\t" + "mul.f32x2 x23, x23, %2; \n\t" + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" + "mov.b16 %0, {q0, q1}; \n\t" + "}" + : "=h"(out_fp4x4) + : "l"(reinterpret_cast(input_fp16x4)), + "l"(reinterpret_cast(scale))); + return out_fp4x4; +} + +__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs( + const fp16x4 input_fp16x4, + const float2 scale, + uint32_t rbits) { + uint16_t out_fp4x4 = 0; + asm volatile( + "{\n" + ".reg.b16 x0_fp16; \n\t" + ".reg.b16 x1_fp16; \n\t" + ".reg.b16 x2_fp16; \n\t" + ".reg.b16 x3_fp16; \n\t" + ".reg.b32 x0; \n\t" + ".reg.b32 x1; \n\t" + ".reg.b32 x2; \n\t" + ".reg.b32 x3; \n\t" + ".reg.b64 x01; \n\t" + ".reg.b64 x23; \n\t" + ".reg.b16 q0; \n\t" + "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t" + "cvt.f32.f16 x0, x0_fp16; \n\t" + "cvt.f32.f16 x1, x1_fp16; \n\t" + "cvt.f32.f16 x2, x2_fp16; \n\t" + "cvt.f32.f16 x3, x3_fp16; \n\t" + "mov.b64 x01, {x0, x1}; \n\t" + "mul.f32x2 x01, x01, %2; \n\t" + "mov.b64 x23, {x2, x3}; \n\t" + "mul.f32x2 x23, x23, %2; \n\t" + "mov.b64 {x0, x1}, x01; \n\t" + "mov.b64 {x2, x3}, x23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, 
x0}, %3; \n\t" + "}" + : "=h"(out_fp4x4) + : "l"(reinterpret_cast(input_fp16x4)), + "l"(reinterpret_cast(scale)), + "r"(rbits)); + return out_fp4x4; +} + +template +__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4( + const bf16x4 input, + const float scale, + uint32_t rbits) { + float2 scale_fp32x2 = make_float2(scale, scale); + if constexpr (USE_SR) { + return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits); + } else { + return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2); + } +} + +template +__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4( + const fp16x4 input, + const float scale, + uint32_t rbits) { + float2 scale_fp32x2 = make_float2(scale, scale); + if constexpr (USE_SR) { + return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits); + } else { + return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2); + } +} + +template +__device__ __forceinline__ uint16_t +scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) { + float2 scale_fp32x2 = make_float2(scale, scale); + float2 input_fp32x2_0 = make_float2(input.x, input.y); + float2 input_fp32x2_1 = make_float2(input.z, input.w); + + if constexpr (USE_SR) { + return scale_cvt_fp32x4_to_fp4x4_rs( + input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits); + } else { + return scale_cvt_fp32x4_to_fp4x4_rn( + input_fp32x2_0, input_fp32x2_1, scale_fp32x2); + } +} + +template +__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast( + const Vector4_t input, + const float scale, + uint32_t rbits) { + if constexpr (std::is_same::value) { + return scale_cvt_bf16x4_to_fp4x4(input, scale, rbits); + } else if constexpr (std::is_same::value) { + return scale_cvt_fp16x4_to_fp4x4(input, scale, rbits); + } else { + return scale_cvt_f32x4_to_fp4x4(input, scale, rbits); + } +} +#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && + // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000) + +template +__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4( + const Vector4_t input, + const float scale, + uint32_t rbits) { +#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ + (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000) + return scale_cvt_Tx4_to_fp4x4_fast(input, scale, rbits); +#else + static_assert( + !USE_SR, + "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000."); + return scale_cvt_Tx4_to_fp4x4_fallback(input, scale); +#endif +} +} // namespace mlx::core::cu \ No newline at end of file diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index c6a85527c..e589c9705 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() { return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3);
 }
 
+template <typename T>
+__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
+  if constexpr (
+      (std::is_same<T, __nv_bfloat162>::value) ||
+      (std::is_same<T, __half2>::value)) {
+    T a = x1;
+    T b = x2;
+    out = __hmax2(__habs2(a), __habs2(b));
+  } else if constexpr (std::is_same<T, float2>::value) {
+    float2 a = x1;
+    float2 b = x2;
+    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
+    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
+  }
+}
+
 } // namespace cu
 
 template 
diff --git a/mlx/backend/cuda/steel/tiles.cuh b/mlx/backend/cuda/steel/tiles.cuh
index be6c46648..458380287 100644
--- a/mlx/backend/cuda/steel/tiles.cuh
+++ b/mlx/backend/cuda/steel/tiles.cuh
@@ -3,31 +3,10 @@
 #pragma once
 
 #include "mlx/backend/cuda/steel/utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 
 namespace mlx::core::cu {
 
-// Map types to their vector of 2 type float -> float2, double -> double2 etc
-template <typename T>
-struct Vector2;
-template <>
-struct Vector2<double> {
-  using type = double2;
-};
-template <>
-struct Vector2<float> {
-  using type = float2;
-};
-template <>
-struct Vector2<__half> {
-  using type = __half2;
-};
-template <>
-struct Vector2<__nv_bfloat16> {
-  using type = __nv_bfloat162;
-};
-template <typename T>
-using Vector2_t = typename Vector2<T>::type;
-
 /**
  * The basic building block for Ampere mmas. A 16x16 tile distributed across
  * the warp.
diff --git a/mlx/backend/cuda/vector_types.cuh b/mlx/backend/cuda/vector_types.cuh
new file mode 100644
index 000000000..2a7c09223
--- /dev/null
+++ b/mlx/backend/cuda/vector_types.cuh
@@ -0,0 +1,48 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+namespace mlx::core::cu {
+
+template <typename T>
+struct Vector2;
+
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
+template <typename T>
+struct Vector4 {
+  T x, y, z, w;
+};
+
+template <typename T>
+using Vector4_t = Vector4<T>;
+
+using bf16x4 = Vector4_t<__nv_bfloat16>;
+using fp16x4 = Vector4_t<__half>;
+using fp32x4 = Vector4_t<float>;
+
+} // namespace mlx::core::cu
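
The new kernel assigns one group of group_size elements to each thread: it loads the group with load_vector, reduces the group's absolute maximum two elements at a time with abs_max_x2, derives the scale as amax / 6.0 for 4-bit output (6.0 is the largest E2M1 magnitude) or amax / 448.0 for 8-bit output (448 is the largest finite E4M3 value), stores the scale as one byte per group, and packs the quantized elements. The following is a host-side reference sketch of the 4-bit path under those assumptions; the names (quantize_e2m1, quantize_group_fp4, dequantize_e2m1) are illustrative and not taken from the patch, the scale is kept as a plain float instead of being re-encoded as the 8-bit "mx scale or nv scale" the kernel writes, and tie-breaking may differ slightly from the cvt.rn.satfinite.e2m1x2 conversion used on the fast path.

// Host-side reference for one nvfp4-style group (group_size = 16, bits = 4).
// Illustrative re-implementation in plain C++, not the kernel above: the real
// kernel stores the scale as an 8-bit floating-point value and converts with
// __nv_fp4_e2m1 / cvt.rn.satfinite.e2m1x2.
#include <cmath>
#include <cstdint>
#include <cstdio>

namespace {

// The eight non-negative values representable in FP4 E2M1.
constexpr float kE2M1Values[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Round to the nearest E2M1 code (sign in bit 3). Tie handling here is simple
// nearest-magnitude and may differ from the hardware round-to-nearest-even.
uint8_t quantize_e2m1(float x) {
  uint8_t sign = x < 0.0f ? 0x8 : 0x0;
  float mag = std::fabs(x);
  int best = 0;
  for (int i = 1; i < 8; ++i) {
    if (std::fabs(mag - kE2M1Values[i]) < std::fabs(mag - kE2M1Values[best])) {
      best = i;
    }
  }
  return sign | static_cast<uint8_t>(best);
}

// Quantize one group of 16 floats: 8 packed bytes (two FP4 codes per byte,
// low nibble = even element, matching the fallback packing above) plus the
// group scale, kept as a float here for clarity.
void quantize_group_fp4(const float (&w)[16], uint8_t (&packed)[8], float& scale) {
  float amax = 0.0f;
  for (float v : w) {
    amax = std::fmax(amax, std::fabs(v));
  }
  scale = amax / 6.0f; // 6.0 is the largest E2M1 magnitude
  for (int i = 0; i < 16; i += 2) {
    uint8_t lo = quantize_e2m1(scale == 0.0f ? 0.0f : w[i] / scale);
    uint8_t hi = quantize_e2m1(scale == 0.0f ? 0.0f : w[i + 1] / scale);
    packed[i / 2] = static_cast<uint8_t>((hi << 4) | lo);
  }
}

// Dequantize one code for a quick round-trip check.
float dequantize_e2m1(uint8_t code, float scale) {
  float mag = kE2M1Values[code & 0x7];
  return (code & 0x8 ? -mag : mag) * scale;
}

} // namespace

int main() {
  float w[16] = {0.1f, -0.2f, 0.3f,  1.9f, -2.5f, 0.0f,  0.7f,  -1.1f,
                 3.0f, -4.2f, 0.05f, 2.2f, -0.6f, 1.25f, -3.3f, 0.9f};
  uint8_t packed[8];
  float scale;
  quantize_group_fp4(w, packed, scale);

  printf("scale = %f\n", scale);
  for (int i = 0; i < 16; ++i) {
    uint8_t code = (packed[i / 2] >> ((i % 2) * 4)) & 0xF;
    printf("w[%2d] = %+6.2f  ->  %+6.2f\n", i, w[i], dequantize_e2m1(code, scale));
  }
  return 0;
}

With the sample group above, amax is 4.2, so the scale is 0.7; values such as -4.2 survive the round trip exactly, while the rest land on the nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6} times the scale.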
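
The fallback pack order in scale_cvt_Tx4_to_fp4x4_fallback puts element i into bits 4*i..4*i+3 of the returned uint16_t, which is also why the kernel stores the result through a uint16_t pointer at &quantized[i * 2]: on a little-endian target the low byte carries elements 0 and 1, giving two FP4 codes per output byte. A minimal check of that layout, with pack_fp4x4 as a hypothetical stand-in for the shifts in the fallback:

// Layout check for the 4-element FP4 pack used by the fallback path:
// element i lands in bits [4*i, 4*i + 3] of the returned uint16_t.
#include <cstdint>
#include <cstdio>

uint16_t pack_fp4x4(uint8_t q0, uint8_t q1, uint8_t q2, uint8_t q3) {
  return static_cast<uint16_t>(q3) << 12 | static_cast<uint16_t>(q2) << 8 |
      static_cast<uint16_t>(q1) << 4 | static_cast<uint16_t>(q0);
}

int main() {
  // In E2M1 (sign in bit 3): 0x2 = 1.0, 0x5 = 3.0, 0x8 = -0.0, 0x7 = 6.0.
  uint16_t packed = pack_fp4x4(0x2, 0x5, 0x8, 0x7);
  printf("packed = 0x%04x\n", packed); // prints 0x7852
  return 0;
}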