diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h
index 20d9c60f6..b68091c98 100644
--- a/mlx/backend/cpu/unary_ops.h
+++ b/mlx/backend/cpu/unary_ops.h
@@ -120,7 +120,7 @@ Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
 struct ToFP8 {
   template <int N>
   Simd<uint8_t, N> operator()(Simd<float, N> f) {
-    uint32_t fp8_max = 1087 << 20;
+    uint32_t fp8_max = 543 << 21;
     auto denorm_mask = Simd<uint32_t, N>(141 << 23);
     Simd<uint32_t, N> f_bits;
     Simd<float, N> f32 = f;
diff --git a/mlx/backend/cuda/quantized/cuda_fp4.h b/mlx/backend/cuda/quantized/cuda_fp4.h
index c107a38c7..10df45795 100644
--- a/mlx/backend/cuda/quantized/cuda_fp4.h
+++ b/mlx/backend/cuda/quantized/cuda_fp4.h
@@ -1,7 +1,22 @@
 #pragma once
 
 struct __nv_fp8_e8m0 {
-  __device__ __nv_fp8_e8m0(uint8_t x) : __x(x) {}
+  __device__ __nv_fp8_e8m0(float x) {
+    if (!std::isfinite(x)) {
+      __x = 0xFF;
+      return;
+    }
+    if (x < 0.0f) {
+      __x = 0x00;
+      return;
+    }
+    float le = std::log2f(x);
+    int n = static_cast<int>(std::nearbyintf(le));
+
+    n = n < -127 ? -127 : n;
+    n = n > 127 ? 127 : n;
+    __x = static_cast<uint8_t>(n + 127);
+  }
 
   __device__ operator float() {
     if (__x == 0xFF) {
diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu
index bc50ffef9..0f979dfb0 100644
--- a/mlx/backend/cuda/quantized/fp_quantize.cu
+++ b/mlx/backend/cuda/quantized/fp_quantize.cu
@@ -49,13 +49,12 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   auto grid_dim_x =
       cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
-  size_t out_index = tidx + grid_dim_x * size_t(tidy);
-  size_t in_index = out_index;
-  if (in_index >= size) {
+  size_t index = tidx + grid_dim_x * size_t(tidy);
+  if (index >= size) {
     return;
   }
-  float w_thread = w[in_index];
+  float w_thread = w[index];
 
   cg::greater<float> max_op;
   auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());
@@ -70,21 +69,19 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   scale = float(s);
 
   // Write out the scales
-  size_t gindex = in_index / group_size;
-  if (in_index % group_size == 0) {
+  size_t gindex = index / group_size;
+  if (index % group_size == 0) {
     scales[gindex] = q_scale;
   }
 
-  uint8_t output = 0;
-  uint8_t val = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
-  output = val;
+  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
   if (bits == 4) {
-    uint8_t sval = warp.shfl_down(val, 1);
+    uint8_t sval = warp.shfl_down(output, 1);
     output |= sval << bits;
   }
   constexpr int pack_factor = bits == 8 ? 1 : 2;
-  if (out_index % pack_factor == 0) {
-    out[out_index / pack_factor] = output;
+  if (index % pack_factor == 0) {
+    out[index / pack_factor] = output;
   }
 }
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 16225e181..628a35ae9 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -29,7 +29,7 @@ make_jit_source(
   kernels/bf16_math.h
   kernels/complex.h
   kernels/defines.h)
-make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
+make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 70faa1d24..cd743f0a8 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -6,6 +6,7 @@ set(BASE_HEADERS
   defines.h
   erf.h
   expm1f.h
+  fp8.h
   utils.h)
 
 function(build_kernel_base TARGET SRCFILE DEPS)
@@ -109,7 +110,7 @@ if(NOT MLX_METAL_JIT)
     reduction/reduce_col.h
     reduction/reduce_row.h)
   build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
-  build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
+  build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
   build_kernel(scan scan.h)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
diff --git a/mlx/backend/metal/kernels/fp4.h b/mlx/backend/metal/kernels/fp4.h
new file mode 100644
index 000000000..40742cc31
--- /dev/null
+++ b/mlx/backend/metal/kernels/fp4.h
@@ -0,0 +1,56 @@
+#pragma once
+
+constexpr constant static float FP4_LUT[16] = {
+    +0.0f,
+    +0.5f,
+    +1.0f,
+    +1.5f,
+    +2.0f,
+    +3.0f,
+    +4.0f,
+    +6.0f,
+    -0.0f,
+    -0.5f,
+    -1.0f,
+    -1.5f,
+    -2.0f,
+    -3.0f,
+    -4.0f,
+    -6.0f};
+
+struct fp4_e2m1 {
+  fp4_e2m1(float x) {
+    if (metal::isnan(x)) {
+      bits = 0x7;
+      return;
+    }
+
+    const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
+    x = metal::abs(x);
+
+    if (x > 5.0f) {
+      bits = 0x7;
+    } else if (x >= 3.5f) {
+      bits = 0x6;
+    } else if (x > 2.5f) {
+      bits = 0x5;
+    } else if (x >= 1.75f) {
+      bits = 0x4;
+    } else if (x > 1.25f) {
+      bits = 0x3;
+    } else if (x >= 0.75f) {
+      bits = 0x2;
+    } else if (x > 0.25f) {
+      bits = 0x1;
+    } else {
+      bits = 0x0;
+    }
+    bits |= sign_bit;
+  }
+
+  operator float() {
+    return FP4_LUT[bits];
+  }
+
+  uint8_t bits;
+};
diff --git a/mlx/backend/metal/kernels/fp8.h b/mlx/backend/metal/kernels/fp8.h
new file mode 100644
index 000000000..4b1836a39
--- /dev/null
+++ b/mlx/backend/metal/kernels/fp8.h
@@ -0,0 +1,88 @@
+#pragma once
+
+inline float fp32_from_bits(uint32_t bits) {
+  return *(reinterpret_cast<thread float*>(&bits));
+}
+inline uint32_t fp32_to_bits(float x) {
+  return *(reinterpret_cast<thread uint32_t*>(&x));
+}
+
+struct fp8_e4m3 {
+  template <typename T>
+  fp8_e4m3(T f) {
+    // From PyTorch
+    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
+    uint32_t fp8_max = 543 << 21;
+    uint32_t denorm_mask = 141 << 23;
+    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint32_t sign = f_bits & 0x80000000;
+    f_bits ^= sign;
+    if (f_bits >= fp8_max) {
+      // Default behavior saturates to min/max
+      bits = 0x7E;
+    } else {
+      if (f_bits < (121 << 23)) {
+        f_bits =
+            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        bits = static_cast<uint8_t>(f_bits - denorm_mask);
+      } else {
+        // resulting mantissa is odd
+        uint8_t mant_odd = (f_bits >> 20) & 1;
+        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
+        f_bits += mant_odd;
+        bits = static_cast<uint8_t>(f_bits >> 20);
+      }
+    }
+    bits |= static_cast<uint8_t>(sign >> 24);
+  }
+
+  operator float() {
+    // From PyTorch:
+    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
+    uint32_t w = static_cast<uint32_t>(bits) << 24;
+    uint32_t sign = w & 0x80000000;
+    uint32_t nonsign = w & 0x7FFFFFFF;
+
+    uint32_t renorm_shift = metal::clz(nonsign);
+    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
+
+    int32_t inf_nan_mask =
+        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
+    int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
+    uint32_t result = sign |
+        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+          inf_nan_mask) &
+         ~zero_mask);
+    return fp32_from_bits(result);
+  }
+
+  uint8_t bits;
+};
+
+struct fp8_e8m0 {
+  fp8_e8m0(float x) {
+    if (!metal::isfinite(x)) {
+      bits = 0xFF;
+      return;
+    }
+    if (x < 0.0f) {
+      bits = 0x00;
+      return;
+    }
+    float le = metal::log2(x);
+    int n = int(metal::round(le));
+
+    n = n < -127 ? -127 : n;
+    n = n > 127 ? 127 : n;
+    bits = static_cast<uint8_t>(n + 127);
+  }
+
+  operator float() {
+    if (bits == 0xFF) {
+      return metal::numeric_limits<float>::quiet_NaN();
+    }
+    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
+  }
+
+  uint8_t bits;
+};
diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h
similarity index 94%
rename from mlx/backend/metal/kernels/fp4_quantized.h
rename to mlx/backend/metal/kernels/fp_quantized.h
index 0b22dc1e5..2d2809bb7 100644
--- a/mlx/backend/metal/kernels/fp4_quantized.h
+++ b/mlx/backend/metal/kernels/fp_quantized.h
@@ -59,28 +59,10 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
   }
 }
 
-constexpr constant static float MXFP4_LUT[16] = {
-    +0.0f,
-    +0.5f,
-    +1.0f,
-    +1.5f,
-    +2.0f,
-    +3.0f,
-    +4.0f,
-    +6.0f,
-    -0.0f,
-    -0.5f,
-    -1.0f,
-    -1.5f,
-    -2.0f,
-    -3.0f,
-    -4.0f,
-    -6.0f};
-
 template <typename T>
 void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
   if (simd_gid == 0 && simd_lid < 16) {
-    lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
+    lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
 }
@@ -1789,3 +1771,100 @@ template <
     }
   }
 }
+
+template <int bits>
+struct Quantize {
+  uint8_t operator()(float x) {
+    if constexpr (bits == 8) {
+      return fp8_e4m3(x).bits;
+    } else {
+      return fp4_e2m1(x).bits;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  float operator()(uint8_t x) {
+    if constexpr (bits == 8) {
+      return float(*(thread fp8_e4m3*)(&x));
+    } else {
+      return float(*(thread fp4_e2m1*)(&x));
+    }
+  }
+};
+
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device uint8_t* scales [[buffer(2)]],
+    uint2 tidx [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  constexpr bool use_mx_scale = group_size == 32;
+  size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
+
+  float scale;
+  float w_thread = w[index];
+  if (use_mx_scale) {
+    scale = simd_max(abs(w_thread));
+  } else {
+    float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
+    float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
+    scale = tidx.x < 16 ? w_max_l : w_max_r;
+  }
+  scale /= bits == 4 ? 6.0f : 448.0f;
+
+  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
+  auto s = ScaleType(scale);
+  uint8_t q_scale = s.bits;
+  scale = float(s);
+
+  // Write out the scales and biases
+  size_t gindex = index / group_size;
+  if (index % group_size == 0) {
+    scales[gindex] = q_scale;
+  }
+
+  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
+  if (bits == 4) {
+    uint8_t sval = simd_shuffle_down(output, 1);
+    output |= sval << bits;
+  }
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  if (index % pack_factor == 0) {
+    out[index / pack_factor] = output;
+  }
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_dequantize(
+    const device uint8_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    device T* out [[buffer(3)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  constexpr bool use_mx_scale = group_size == 32;
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  size_t oindex = offset * pack_factor;
+  size_t gindex = oindex / group_size;
+
+  out += oindex;
+
+  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
+  auto q_scale = ((device ScaleType*)(scales))[gindex];
+  auto scale = float(q_scale);
+
+  uint val = w[offset];
+#pragma clang loop unroll(full)
+  for (int i = 0; i < pack_factor; i++) {
+    uint8_t d;
+    if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
+  }
+}
diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp_quantized.metal
similarity index 83%
rename from mlx/backend/metal/kernels/fp4_quantized.metal
rename to mlx/backend/metal/kernels/fp_quantized.metal
index 6b2daf88c..bef26b786 100644
--- a/mlx/backend/metal/kernels/fp4_quantized.metal
+++ b/mlx/backend/metal/kernels/fp_quantized.metal
@@ -4,7 +4,9 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/quantized_utils.h"
-#include "mlx/backend/metal/kernels/fp4_quantized.h"
+#include "mlx/backend/metal/kernels/fp8.h"
+#include "mlx/backend/metal/kernels/fp4.h"
+#include "mlx/backend/metal/kernels/fp_quantized.h"
 
 #define instantiate_quantized(name, type) \
   instantiate_kernel( \
@@ -113,13 +115,33 @@
   instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
   instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
 
+#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
+  instantiate_kernel( \
+      mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
+      fp_quantize, \
+      type, \
+      group_size, \
+      bits) \
+  instantiate_kernel( \
+      mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
+      fp_dequantize, \
+      type, \
+      group_size, \
+      bits)
+
+#define instantiate_quantize_dequantize_modes(type) \
+  instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \
+  instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \
+  instantiate_quantize_dequantize(type, "mxfp8", 32, 8)
+
 #define instantiate_quantized_types(type) \
   instantiate_quantized_all_batched(type) \
   instantiate_quantized_all_quad(type) \
   instantiate_quantized_all_splitk(type) \
   instantiate_quantized_all_single(type) \
   instantiate_quantized_all_aligned(type) \
-  instantiate_quantized_all_rhs(type)
+  instantiate_quantized_all_rhs(type) \
+  instantiate_quantize_dequantize_modes(type)
 
 instantiate_quantized_types(float)
 instantiate_quantized_types(bfloat16_t)
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 423c07f66..327bb5a94 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -8,6 +8,7 @@
 #include "mlx/backend/metal/kernels/cexpf.h"
 #include "mlx/backend/metal/kernels/erf.h"
 #include "mlx/backend/metal/kernels/expm1f.h"
+#include "mlx/backend/metal/kernels/fp8.h"
 
 namespace {
 constant float inf = metal::numeric_limits<float>::infinity();
@@ -439,63 +440,15 @@ complex64_t ArcTan::operator()(complex64_t x) {
   return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
 };
 
-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline uint32_t fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
-
 struct ToFP8 {
   template <typename T>
   uint8_t operator()(T f) {
-    // From PyTorch
-    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
-    uint32_t fp8_max = 1087 << 20;
-    uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
-    uint8_t result = 0u;
-    uint32_t sign = f_bits & 0x80000000;
-    f_bits ^= sign;
-    if (f_bits >= fp8_max) {
-      // Default behavior saturates to min/max
-      result = 0x7E;
-    } else {
-      if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
-        result = static_cast<uint8_t>(f_bits - denorm_mask);
-      } else {
-        // resulting mantissa is odd
-        uint8_t mant_odd = (f_bits >> 20) & 1;
-        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
-        f_bits += mant_odd;
-        result = static_cast<uint8_t>(f_bits >> 20);
-      }
-    }
-    result |= static_cast<uint8_t>(sign >> 24);
-    return result;
+    return fp8_e4m3(f).bits;
   }
 };
 
 struct FromFP8 {
   float operator()(uint8_t x) {
-    // From PyTorch:
-    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
-    uint32_t w = static_cast<uint32_t>(x) << 24;
-    uint32_t sign = w & 0x80000000;
-    uint32_t nonsign = w & 0x7FFFFFFF;
-
-    uint32_t renorm_shift = metal::clz(nonsign);
-    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
-
-    int32_t inf_nan_mask =
-        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
-    int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
-    uint32_t result = sign |
-        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
-          inf_nan_mask) &
-         ~zero_mask);
-    return fp32_from_bits(result);
+    return float(*(thread fp8_e4m3*)(&x));
   }
 };
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 328669a92..3a5b9cec3 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -1045,26 +1045,31 @@ void fast::Quantize::eval_gpu(
   compute_encoder.set_input_array(w, 0);
   if (dequantize_) {
     auto scales = ensure_row_contiguous(inputs[1], d, s);
-    auto biases = ensure_row_contiguous(inputs[2], d, s);
     compute_encoder.set_input_array(scales, 1);
-    compute_encoder.set_input_array(biases, 2);
     compute_encoder.set_output_array(out, 3);
+    if (mode_ == QuantizationMode::Affine) {
+      auto biases = ensure_row_contiguous(inputs[2], d, s);
+      compute_encoder.set_input_array(biases, 2);
+    }
   } else {
     auto& scales = outputs[1];
-    auto& biases = outputs[2];
     scales.set_data(allocator::malloc(scales.nbytes()));
-    biases.set_data(allocator::malloc(biases.nbytes()));
     compute_encoder.set_output_array(out, 1);
     compute_encoder.set_output_array(scales, 2);
-    compute_encoder.set_output_array(biases, 3);
+    if (mode_ == QuantizationMode::Affine) {
+      auto& biases = outputs[2];
+      biases.set_data(allocator::malloc(biases.nbytes()));
+      compute_encoder.set_output_array(biases, 3);
+    }
   }
 
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
+  auto mode = quantization_mode_to_string(mode_);
   std::string kname;
   concatenate(
       kname,
-      dequantize_ ? "affine_dequantize" : "affine_quantize",
+      mode + (dequantize_ ? "_dequantize" : "_quantize"),
       "_",
       type_string,
       "_gs_",
@@ -1075,7 +1080,7 @@ void fast::Quantize::eval_gpu(
       d,
       kname,
       dequantize_ ? "dequantize" : "quantize",
-      "affine",
+      mode,
       type_string,
       group_size_,
       bits_);
@@ -1088,7 +1093,8 @@ void fast::Quantize::eval_gpu(
   int packs_per_int =
       (bits_ == 3 || bits_ == 5) ? 8 : bits_ == 6 ? 4 : 8 / bits_;
-  int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
+  int per_thread =
+      dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);
   size_t nthreads =
       dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index f6b0ee04c..613cc5a1e 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2,7 +2,6 @@
 
 // Required for using M_PI in MSVC.
 #define _USE_MATH_DEFINES
-
 #include <algorithm>
 #include <climits>
 #include <cmath>
@@ -4259,8 +4258,11 @@ std::vector<array> fp_quantize(
   } else {
     // convert to e8m0
     auto z = array(0, scales.dtype());
-    scales =
-        where(equal(scales, z, s), z, astype(log2(scales, s), int32, s), s);
+    scales = where(
+        equal(scales, z, s),
+        z,
+        astype(round(log2(scales, s), s), int32, s),
+        s);
     wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
     scales = astype(add(scales, array(127, int32), s), uint8, s);
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index d828cc69b..ad8dffab6 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -92,7 +92,6 @@ class TestQuantized(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
-
         w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
 
         with self.assertRaises(ValueError):
@@ -102,7 +101,8 @@
             mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
 
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
-        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
+
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))
 
         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
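
Reviewer note (not part of the patch): the new fp8_max constant 543 << 21 is the bit pattern of 448.0f, replacing 1087 << 20, the bit pattern of 480.0f, so inputs at or above 448 now take the saturating branch rather than the rounding path (which can overflow into the 0x7F NaN encoding for values just under 480). The following host-side C++ sketch mirrors that e4m3 packing for desk-checking the bit manipulation; the helper names are illustrative assumptions, not MLX API, and it assumes the saturate-to-448 behavior used in this diff.

// Illustrative reference only; assumes the PyTorch-derived algorithm used above.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t fp32_to_bits(float x) {
  uint32_t b;
  std::memcpy(&b, &x, sizeof(b));
  return b;
}

static float fp32_from_bits(uint32_t b) {
  float x;
  std::memcpy(&x, &b, sizeof(x));
  return x;
}

// Pack a float into fp8 e4m3 bits, saturating to +/-448 as in the patch.
static uint8_t to_fp8_e4m3(float f) {
  const uint32_t fp8_max = 543u << 21;     // bit pattern of 448.0f
  const uint32_t denorm_mask = 141u << 23; // moves subnormal results into the low bits
  uint32_t f_bits = fp32_to_bits(f);
  uint32_t sign = f_bits & 0x80000000u;
  f_bits ^= sign;
  uint8_t result;
  if (f_bits >= fp8_max) {
    result = 0x7E; // largest finite e4m3 value (448)
  } else if (f_bits < (121u << 23)) {
    // Result is subnormal: add a magic constant so the mantissa lands in place.
    f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
    result = static_cast<uint8_t>(f_bits - denorm_mask);
  } else {
    uint8_t mant_odd = (f_bits >> 20) & 1; // round to nearest even
    f_bits += (static_cast<uint32_t>(7 - 127) << 23) + 0x7FFFF;
    f_bits += mant_odd;
    result = static_cast<uint8_t>(f_bits >> 20);
  }
  return result | static_cast<uint8_t>(sign >> 24);
}

int main() {
  const float samples[] = {0.0f, 0.3f, 1.0f, -2.7f, 448.0f, 1000.0f};
  for (float v : samples) {
    std::printf("%10.3f -> 0x%02X\n", static_cast<double>(v), to_fp8_e4m3(v));
  }
  return 0;
}

Running it should print, for example, 1.000 -> 0x38 and 1000.000 -> 0x7E, matching what the same bit manipulation produces in the kernels above.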