From ec72b444172264e3238c9b35e5b1d188b2d06168 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 28 Oct 2025 16:23:12 -0700 Subject: [PATCH] Add quantize/dequantize for mxfp8 and nvfp4 (#2688) * Add quantize/dequantize slow path for mxfp8 and nvfp4 * fast cuda kernel for mx/nv quantization * fallback for cuda < 12.8 (#2697) * format (#2700) * fix (#2701) * metal kernels * docs * fix jit * add default bits and group sizes * improve quant docs * fix output type of mxfp4 matmuls --- mlx/backend/cpu/unary_ops.h | 2 +- mlx/backend/cuda/CMakeLists.txt | 6 + mlx/backend/cuda/quantized/affine_quantize.cu | 2 +- mlx/backend/cuda/quantized/cuda_fp4.h | 83 +++ mlx/backend/cuda/quantized/fp_quantize.cu | 216 ++++++++ mlx/backend/cuda/quantized/quantized.cpp | 19 +- mlx/backend/cuda/quantized/quantized.h | 18 + mlx/backend/metal/CMakeLists.txt | 5 +- mlx/backend/metal/jit/includes.h | 2 +- mlx/backend/metal/jit_kernels.cpp | 51 +- mlx/backend/metal/kernels/CMakeLists.txt | 4 +- mlx/backend/metal/kernels/fp4.h | 56 ++ mlx/backend/metal/kernels/fp4_quantized.metal | 127 ----- mlx/backend/metal/kernels/fp8.h | 88 ++++ .../{fp4_quantized.h => fp_quantized.h} | 298 +++++++---- mlx/backend/metal/kernels/fp_quantized.metal | 147 ++++++ mlx/backend/metal/kernels/unary_ops.h | 53 +- mlx/backend/metal/quantized.cpp | 33 +- mlx/fast_primitives.h | 2 +- mlx/ops.cpp | 487 +++++++++++------- mlx/ops.h | 17 +- mlx/primitives.cpp | 32 +- mlx/primitives.h | 6 +- python/src/ops.cpp | 116 +++-- python/tests/test_quantized.py | 118 ++++- 25 files changed, 1400 insertions(+), 588 deletions(-) create mode 100644 mlx/backend/cuda/quantized/cuda_fp4.h create mode 100644 mlx/backend/cuda/quantized/fp_quantize.cu create mode 100644 mlx/backend/metal/kernels/fp4.h delete mode 100644 mlx/backend/metal/kernels/fp4_quantized.metal create mode 100644 mlx/backend/metal/kernels/fp8.h rename mlx/backend/metal/kernels/{fp4_quantized.h => fp_quantized.h} (89%) create mode 100644 mlx/backend/metal/kernels/fp_quantized.metal diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index 20d9c60f6..b68091c98 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -120,7 +120,7 @@ Simd fp32_to_bits(Simd x) { struct ToFP8 { template Simd operator()(Simd f) { - uint32_t fp8_max = 1087 << 20; + uint32_t fp8_max = 543 << 21; auto denorm_mask = Simd(141 << 23); Simd f_bits; Simd f32 = f; diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index eabee94f2..7f8f1aade 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -51,6 +51,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) @@ -58,6 +59,11 @@ target_sources( add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) +# fp4 is not available on < 12.8 +if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0) + target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/) +endif() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) diff --git a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu index 
94e67d135..a64597a88 100644 --- a/mlx/backend/cuda/quantized/affine_quantize.cu +++ b/mlx/backend/cuda/quantized/affine_quantize.cu @@ -306,7 +306,7 @@ void affine_dequantize( enc.set_input_array(scales); enc.set_input_array(biases); enc.set_output_array(w); - dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) { + dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) { dispatch_groups(group_size_, [&](auto group_size) { dispatch_bits(bits_, [&](auto bits) { using T = cuda_type_t; diff --git a/mlx/backend/cuda/quantized/cuda_fp4.h b/mlx/backend/cuda/quantized/cuda_fp4.h new file mode 100644 index 000000000..10df45795 --- /dev/null +++ b/mlx/backend/cuda/quantized/cuda_fp4.h @@ -0,0 +1,83 @@ +#pragma once + +struct __nv_fp8_e8m0 { + __device__ __nv_fp8_e8m0(float x) { + if (!std::isfinite(x)) { + __x = 0xFF; + return; + } + if (x < 0.0f) { + __x = 0x00; + return; + } + float le = std::log2f(x); + int n = static_cast(std::nearbyintf(le)); + + n = n < -127 ? -127 : n; + n = n > 127 ? 127 : n; + __x = static_cast(n + 127); + } + + __device__ operator float() { + if (__x == 0xFF) { + return std::numeric_limits::quiet_NaN(); + } + return std::ldexp(1.0f, static_cast(__x) - 127); + } + + uint8_t __x{0}; +}; + +struct __nv_fp4_e2m1 { + __device__ __nv_fp4_e2m1(float x) { + if (std::isnan(x)) { + __x = 0x7; + return; + } + + const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0; + x = std::abs(x); + + if (x > 5.0f) { + __x = 0x7; + } else if (x >= 3.5f) { + __x = 0x6; + } else if (x > 2.5f) { + __x = 0x5; + } else if (x >= 1.75f) { + __x = 0x4; + } else if (x > 1.25f) { + __x = 0x3; + } else if (x >= 0.75f) { + __x = 0x2; + } else if (x > 0.25f) { + __x = 0x1; + } else { + __x = 0x0; + } + __x |= sign_bit; + } + + __device__ operator float() { + static const float LUT[16] = { + 0.0f, + 0.5f, + 1.0f, + 1.5f, + 2.0f, + 3.0f, + 4.0f, + 6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + + return LUT[__x]; + } + uint8_t __x{0}; +}; diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu new file mode 100644 index 000000000..0f979dfb0 --- /dev/null +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -0,0 +1,216 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/quantized.h" +#include "mlx/dtype_utils.h" + +#include +#include +#include +#include + +namespace mlx::core { +namespace cu { + +template +struct Quantize { + __device__ uint8_t operator()(float x) { + if constexpr (bits == 8) { + return __nv_fp8_e4m3(x).__x; + } else { + return __nv_fp4_e2m1(x).__x; + } + } +}; + +template +struct Dequantize { + __device__ float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(__nv_fp8_e4m3*)(&x)); + } else { + return float(*(__nv_fp4_e2m1*)(&x)); + } + } +}; + +namespace cg = cooperative_groups; + +template +__global__ void +fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + size_t index = tidx + grid_dim_x * size_t(tidy); + if (index >= size) { + return; + } + + float w_thread = w[index]; + + cg::greater max_op; + auto warp = cg::tiled_partition(cg::this_thread_block()); + + float scale = cg::reduce(warp, abs(w_thread), max_op); + scale /= bits == 4 ? 6.0f : 448.0f; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale); + uint8_t q_scale = s.__x; + scale = float(s); + + // Write out the scales + size_t gindex = index / group_size; + if (index % group_size == 0) { + scales[gindex] = q_scale; + } + + uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + if (bits == 4) { + uint8_t sval = warp.shfl_down(output, 1); + output |= sval << bits; + } + constexpr int pack_factor = bits == 8 ? 1 : 2; + if (index % pack_factor == 0) { + out[index / pack_factor] = output; + } +} + +template +__global__ void +fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + + constexpr int pack_factor = bits == 8 ? 
1 : 2; + size_t offset = tidx + grid_dim_x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + using ScaleType = + std::conditional_t; + auto scale = float(((ScaleType*)(scales))[gindex]); + + out += oindex; + + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} + +} // namespace cu + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize; + if (bits == 8) { + kernel = cu::fp_quantize; + } else if (group_size == 16) { + kernel = cu::fp_quantize; + } + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + w.data(), + wq.data(), + scales.data(), + w.size()); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); +} + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s) { + constexpr int uint8_per_uint32 = 4; + int packs_per_int = 8 / bits; + + size_t size = w.size() / packs_per_int; + bool large = size > UINT_MAX; + auto grid_shape = w.shape(); + grid_shape.back() *= uint8_per_uint32; + + enc.set_input_array(wq); + enc.set_input_array(scales); + enc.set_output_array(w); + dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_dequantize; + if (bits == 8) { + kernel = cu::fp_dequantize; + } else if (group_size == 16) { + kernel = cu::fp_dequantize; + } + auto [num_blocks, block_dims] = + get_launch_args(size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + wq.data(), + scales.data(), + w.data(), + w.size()); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not dequantize to output with type float64."); + } + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 71c687d85..58710834f 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu( if (dequantize_) { auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); - auto biases = ensure_row_contiguous(inputs[2], enc, s); auto& w = outputs[0]; w.set_data(allocator::malloc(w.nbytes())); - affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous(inputs[2], enc, s); + affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); + } else { + fp_dequantize(wq, scales, w, group_size_, bits_, enc, s); + } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); auto& wq = outputs[0]; auto& scales = outputs[1]; - auto& biases = outputs[2]; 
wq.set_data(allocator::malloc(wq.nbytes())); scales.set_data(allocator::malloc(scales.nbytes())); - biases.set_data(allocator::malloc(biases.nbytes())); - - affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); + if (mode_ == QuantizationMode::Affine) { + auto& biases = outputs[2]; + biases.set_data(allocator::malloc(biases.nbytes())); + affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); + } else { + fp_quantize(w, wq, scales, group_size_, bits_, enc, s); + } } } diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index ec6a08000..4f1980a9c 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -24,4 +24,22 @@ void affine_dequantize( cu::CommandEncoder& enc, const Stream& s); +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 16225e181..0fd1834f6 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -29,7 +29,7 @@ make_jit_source( kernels/bf16_math.h kernels/complex.h kernels/defines.h) -make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) +make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h) make_jit_source(binary_ops) make_jit_source(ternary_ops) make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) @@ -81,7 +81,8 @@ if(MLX_METAL_JIT) make_jit_source(quantized_utils) make_jit_source(quantized kernels/quantized_utils.h) - make_jit_source(fp4_quantized kernels/quantized_utils.h) + make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h + kernels/fp4.h) make_jit_source(gemv_masked) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index f3b57c7f9..c12e576c1 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -24,7 +24,7 @@ const char* hadamard(); const char* logsumexp(); const char* quantized_utils(); const char* quantized(); -const char* fp4_quantized(); +const char* fp_quantized(); const char* ternary(); const char* scan(); const char* scatter_axis(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index e70420cf8..de391abc9 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel( metal::utils(), metal::gemm(), metal::quantized_utils(), - (mode == "affine") ? metal::quantized() : metal::fp4_quantized(), + (mode == "affine") ? 
metal::quantized() : metal::fp_quantized(), template_def); return kernel_source; }); @@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm()); - if (mode == "affine") { - concatenate( - kernel_source, - metal::quantized(), - get_template_definition( - lib_name, - mode + "_gather_qmm_rhs", - get_type_string(x.dtype()), - group_size, - bits, - bm, - bn, - bk, - wm, - wn, - transpose)); - } else { - concatenate( - kernel_source, - metal::fp4_quantized(), - get_template_definition( - lib_name, - mode + "_gather_qmm_rhs", - get_type_string(x.dtype()), - group_size, - "uint8_t", - bm, - bn, - bk, - wm, - wn, - transpose)); - } + bool is_affine = mode == "affine"; + concatenate( + kernel_source, + is_affine ? metal::quantized() : metal::fp_quantized(), + get_template_definition( + lib_name, + (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"), + get_type_string(x.dtype()), + group_size, + bits, + bm, + bn, + bk, + wm, + wn, + transpose)); return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 70faa1d24..69ac2a5e9 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -6,6 +6,7 @@ set(BASE_HEADERS defines.h erf.h expm1f.h + fp8.h utils.h) function(build_kernel_base TARGET SRCFILE DEPS) @@ -109,7 +110,8 @@ if(NOT MLX_METAL_JIT) reduction/reduce_col.h reduction/reduce_row.h) build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS}) - build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS}) + build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h + ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) diff --git a/mlx/backend/metal/kernels/fp4.h b/mlx/backend/metal/kernels/fp4.h new file mode 100644 index 000000000..40742cc31 --- /dev/null +++ b/mlx/backend/metal/kernels/fp4.h @@ -0,0 +1,56 @@ +#pragma once + +constexpr constant static float FP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +struct fp4_e2m1 { + fp4_e2m1(float x) { + if (metal::isnan(x)) { + bits = 0x7; + return; + } + + const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0; + x = metal::abs(x); + + if (x > 5.0f) { + bits = 0x7; + } else if (x >= 3.5f) { + bits = 0x6; + } else if (x > 2.5f) { + bits = 0x5; + } else if (x >= 1.75f) { + bits = 0x4; + } else if (x > 1.25f) { + bits = 0x3; + } else if (x >= 0.75f) { + bits = 0x2; + } else if (x > 0.25f) { + bits = 0x1; + } else { + bits = 0x0; + } + bits |= sign_bit; + } + + operator float() { + return FP4_LUT[bits]; + } + + uint8_t bits; +}; diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp4_quantized.metal deleted file mode 100644 index 6b2daf88c..000000000 --- a/mlx/backend/metal/kernels/fp4_quantized.metal +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright © 2025 Apple Inc. 
- -// clang-format off -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" -#include "mlx/backend/metal/kernels/quantized_utils.h" -#include "mlx/backend/metal/kernels/fp4_quantized.h" - -#define instantiate_quantized(name, type) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4", \ - name, \ - type, \ - 32, \ - uint8_t) - -#define instantiate_quantized_batched(name, type, batched) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_batch_" #batched, \ - name, \ - type, \ - 32, \ - uint8_t, \ - batched) - -#define instantiate_quantized_aligned(name, type, aligned) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_alN_" #aligned, \ - name, \ - type, \ - 32, \ - uint8_t, \ - aligned) - -#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \ - name, \ - type, \ - 32, \ - uint8_t, \ - aligned, \ - batched) - -#define instantiate_quantized_quad(name, type, D, batched) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \ - name, \ - type, \ - 32, \ - uint8_t, \ - D, \ - batched) - -#define instantiate_quantized_split_k(name, type, split_k) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_spk_" #split_k, \ - name, \ - type, \ - 32, \ - uint8_t, \ - split_k) - -#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ - instantiate_kernel( \ - #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ - func, \ - type, \ - 32, \ - uint8_t, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - transpose) - -#define instantiate_quantized_batched_wrap(name, type) \ - instantiate_quantized_batched(name, type, 1) \ - instantiate_quantized_batched(name, type, 0) - -#define instantiate_quantized_all_batched(type) \ - instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \ - instantiate_quantized_batched_wrap(mxfp4_qmv, type) \ - instantiate_quantized_batched_wrap(mxfp4_qvm, type) \ - instantiate_quantized_batched_wrap(mxfp4_qmm_n, type) - -#define instantiate_quantized_all_single(type) \ - instantiate_quantized(mxfp4_gather_qmv_fast, type) \ - instantiate_quantized(mxfp4_gather_qmv, type) \ - instantiate_quantized(mxfp4_gather_qvm, type) \ - instantiate_quantized(mxfp4_gather_qmm_n, type) - -#define instantiate_quantized_all_aligned(type) \ - instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \ - instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0) - -#define instantiate_quantized_all_quad(type) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0) - -#define instantiate_quantized_all_splitk(type) \ - instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \ - instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32) - -#define instantiate_quantized_all_rhs(type) \ - instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ - instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 
32, 1, 2, false) - -#define instantiate_quantized_types(type) \ - instantiate_quantized_all_batched(type) \ - instantiate_quantized_all_quad(type) \ - instantiate_quantized_all_splitk(type) \ - instantiate_quantized_all_single(type) \ - instantiate_quantized_all_aligned(type) \ - instantiate_quantized_all_rhs(type) - -instantiate_quantized_types(float) -instantiate_quantized_types(bfloat16_t) -instantiate_quantized_types(float16_t) - // clang-format on diff --git a/mlx/backend/metal/kernels/fp8.h b/mlx/backend/metal/kernels/fp8.h new file mode 100644 index 000000000..4b1836a39 --- /dev/null +++ b/mlx/backend/metal/kernels/fp8.h @@ -0,0 +1,88 @@ +#pragma once + +inline float fp32_from_bits(uint32_t bits) { + return *(reinterpret_cast(&bits)); +} +inline float fp32_to_bits(float x) { + return *(reinterpret_cast(&x)); +} + +struct fp8_e4m3 { + template + fp8_e4m3(T f) { + // From PyTorch + // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 + uint32_t fp8_max = 543 << 21; + uint32_t denorm_mask = 141 << 23; + uint32_t f_bits = fp32_to_bits(static_cast(f)); + uint32_t sign = f_bits & 0x80000000; + f_bits ^= sign; + if (f_bits >= fp8_max) { + // Default behavior saturates to min/max + bits = 0x7E; + } else { + if (f_bits < (121 << 23)) { + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + bits = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + f_bits += mant_odd; + bits = static_cast(f_bits >> 20); + } + } + bits |= static_cast(sign >> 24); + } + + operator float() { + // From PyTorch: + // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46 + uint32_t w = static_cast(bits) << 24; + uint32_t sign = w & 0x80000000; + uint32_t nonsign = w & 0x7FFFFFFF; + + uint32_t renorm_shift = metal::clz(nonsign); + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + + int32_t inf_nan_mask = + (static_cast(nonsign + 0x01000000) >> 8) & 0x7F800000; + int32_t zero_mask = static_cast(nonsign - 1) >> 31; + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); + } + + uint8_t bits; +}; + +struct fp8_e8m0 { + fp8_e8m0(float x) { + if (!metal::isfinite(x)) { + bits = 0xFF; + return; + } + if (x < 0.0f) { + bits = 0x00; + return; + } + float le = metal::log2(x); + int n = int(metal::round(le)); + + n = n < -127 ? -127 : n; + n = n > 127 ? 
127 : n; + bits = static_cast(n + 127); + } + + operator float() { + if (bits == 0xFF) { + return metal::numeric_limits::quiet_NaN(); + } + return metal::ldexp(1.0f, static_cast(bits) - 127); + } + + uint8_t bits; +}; diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h similarity index 89% rename from mlx/backend/metal/kernels/fp4_quantized.h rename to mlx/backend/metal/kernels/fp_quantized.h index 0b22dc1e5..38e4c3a73 100644 --- a/mlx/backend/metal/kernels/fp4_quantized.h +++ b/mlx/backend/metal/kernels/fp_quantized.h @@ -3,6 +3,9 @@ #include #include +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; @@ -59,28 +62,10 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { } } -constexpr constant static float MXFP4_LUT[16] = { - +0.0f, - +0.5f, - +1.0f, - +1.5f, - +2.0f, - +3.0f, - +4.0f, - +6.0f, - -0.0f, - -0.5f, - -1.0f, - -1.5f, - -2.0f, - -3.0f, - -4.0f, - -6.0f}; - template -void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { +void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { if (simd_gid == 0 && simd_lid < 16) { - lut[simd_lid] = static_cast(MXFP4_LUT[simd_lid]); + lut[simd_lid] = static_cast(FP4_LUT[simd_lid]); } threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -155,8 +140,7 @@ template < short dst_ld, short reduction_dim, short tgp_size, - short group_size, - typename S> + short group_size> struct QuantizedBlockLoader { static_assert( BCOLS <= group_size, @@ -183,12 +167,12 @@ struct QuantizedBlockLoader { threadgroup T* dst; const device uint8_t* src; - const device S* scales; + const device uint8_t* scales; threadgroup T* lut; QuantizedBlockLoader( const device uint8_t* src_, - const device S* scales_, + const device uint8_t* scales_, const int src_ld_, threadgroup T* dst_, threadgroup T* lut_, @@ -208,7 +192,7 @@ struct QuantizedBlockLoader { bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), lut(lut_) { - load_mxfp4_lut(lut, simd_group_id, simd_lane_id); + load_fp4_lut(lut, simd_group_id, simd_lane_id); } void load_unsafe() const { @@ -270,10 +254,10 @@ struct QuantizedBlockLoader { } }; -template -METAL_FUNC void mxfp4_qmv_quad_impl( +template +METAL_FUNC void fp_qmv_quad_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, constant int& in_vec_size, @@ -295,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl( thread U x_thread[values_per_thread]; thread U result[results_per_quadgroup] = {0}; - load_mxfp4_lut(lut, simd_gid, simd_lid); + load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; @@ -311,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl( for (int row = 0; row < results_per_quadgroup; row++) { auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); - const device S* sl = scales + row * in_vec_size_g * quads_per_simd; + const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd; U s = dequantize_scale(sl[0]); if (row * quads_per_simd + out_row < out_vec_size) { @@ -327,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl( } } -template -METAL_FUNC void mxfp4_qmv_fast_impl( +template +METAL_FUNC void fp_qmv_fast_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, 
const device T* x, device T* y, const constant int& in_vec_size, @@ -353,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl( typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - load_mxfp4_lut(lut, simd_gid, simd_lid); + load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; @@ -390,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl( } } -template -METAL_FUNC void mxfp4_qmv_impl( +template +METAL_FUNC void fp_qmv_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -418,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl( thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - load_mxfp4_lut(lut, simd_gid, simd_lid); + load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; @@ -448,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl( auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device auto* sl = scales + row * in_vec_size_g; - S s = sl[0]; + uint8_t s = sl[0]; result[row] += qdot(wl, x_thread, s, lut); } @@ -529,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl( } } -template -METAL_FUNC void mxfp4_qvm_impl( +template +METAL_FUNC void fp_qvm_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const int in_vec_size, @@ -561,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl( thread U scale = 0; thread U x_local = 0; - load_mxfp4_lut(lut, simd_gid, simd_lid); + load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; @@ -633,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl( template < typename T, const int group_size, - typename S, + const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> -METAL_FUNC void mxfp4_qmm_t_impl( +METAL_FUNC void fp_qmm_t_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, threadgroup T* Xs, @@ -677,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl( BK_padded, 1, WM * WN * SIMD_SIZE, - group_size, - S>; + group_size>; // Set the block const int K_w = K * bytes_per_pack / pack_factor; @@ -759,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl( template < typename T, const int group_size, - typename S, + const int bits, const int BM = 32, const int BK = 32, const int BN = 32> -METAL_FUNC void mxfp4_qmm_n_impl( +METAL_FUNC void fp_qmm_n_impl( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, threadgroup T* Xs, @@ -803,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl( BN_padded, 0, WM * WN * SIMD_SIZE, - group_size, - S>; + group_size>; auto wl = (const device uint8_t*)w; @@ -891,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl( } } -template +template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, - const device S*& scales, + const device uint8_t*& scales, device T*& y, int output_stride, const constant int& x_batch_ndims, @@ -926,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, - const device S*& scales, + const device uint8_t*& scales, const device uint32_t* lhs_indices, const device 
uint32_t* rhs_indices, device T*& y, @@ -976,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template -[[kernel]] void mxfp4_qmv_quad( +template +[[kernel]] void fp_qmv_quad( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -1014,7 +996,7 @@ template tid); } threadgroup float lut[16]; - mxfp4_qmv_quad_impl( + fp_qmv_quad_impl( w, scales, x, @@ -1029,10 +1011,10 @@ template lut); } -template -[[kernel]] void mxfp4_qmv_fast( +template +[[kernel]] void fp_qmv_fast( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -1065,14 +1047,14 @@ template tid); } threadgroup float lut[16]; - mxfp4_qmv_fast_impl( + fp_qmv_fast_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_qmv( +template +[[kernel]] void fp_qmv( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -1105,14 +1087,14 @@ template tid); } threadgroup float lut[16]; - mxfp4_qmv_impl( + fp_qmv_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_qvm( +template +[[kernel]] void fp_qvm( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -1145,14 +1127,14 @@ template tid); } threadgroup float lut[16]; - mxfp4_qvm_impl( + fp_qvm_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_qvm_split_k( +template +[[kernel]] void fp_qvm_split_k( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& in_vec_size, @@ -1189,7 +1171,7 @@ template tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; threadgroup float lut[16]; - mxfp4_qvm_impl( + fp_qvm_impl( w, scales, x, @@ -1205,15 +1187,15 @@ template template < typename T, const int group_size, - typename S, + const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void mxfp4_qmm_t( +[[kernel]] void fp_qmm_t( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& K, @@ -1254,21 +1236,21 @@ template < s_strides, tid); } - mxfp4_qmm_t_impl( + fp_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } template < typename T, const int group_size, - typename S, + const int bits, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void mxfp4_qmm_n( +[[kernel]] void fp_qmm_n( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, device T* y, const constant int& K, @@ -1311,14 +1293,14 @@ template < tid); } - mxfp4_qmm_n_impl( + fp_qmm_n_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_gather_qmv_fast( +template +[[kernel]] void fp_gather_qmv_fast( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1361,14 +1343,14 @@ template s_strides, tid); threadgroup float lut[16]; - mxfp4_qmv_fast_impl( + fp_qmv_fast_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_gather_qmv( +template +[[kernel]] void fp_gather_qmv( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1411,14 +1393,14 @@ template s_strides, tid); threadgroup float lut[16]; - mxfp4_qmv_impl( + fp_qmv_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template -[[kernel]] void mxfp4_gather_qvm( +template +[[kernel]] void fp_gather_qvm( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1461,21 +1443,21 @@ template s_strides, tid); threadgroup float lut[16]; - mxfp4_qvm_impl( + fp_qvm_impl( w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } template < typename T, const int group_size, - typename S, + const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void mxfp4_gather_qmm_t( +[[kernel]] void fp_gather_qmm_t( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1526,20 +1508,20 @@ template < w_strides, s_strides, tid); - mxfp4_qmm_t_impl( + fp_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } template < typename T, const int group_size, - typename S, + const int bits, const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void mxfp4_gather_qmm_n( +[[kernel]] void fp_gather_qmm_n( const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1591,24 +1573,24 @@ template < w_strides, s_strides, 
tid); - mxfp4_qmm_n_impl( + fp_qmm_n_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } template < typename T, int group_size, - typename S, + int bits, int BM, int BN, int BK, int WM, int WN, bool transpose> -[[kernel]] void mxfp4_gather_qmm_rhs( +[[kernel]] void fp_gather_qmm_rhs( const device T* x, const device uint32_t* w, - const device S* scales, + const device uint8_t* scales, const device uint32_t* indices, device T* y, const constant int& M, @@ -1644,8 +1626,7 @@ template < transpose ? BK_padded : BN_padded, transpose, WM * WN * SIMD_SIZE, - group_size, - S>; + group_size>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; @@ -1789,3 +1770,100 @@ template < } } } + +template +struct Quantize { + uint8_t operator()(float x) { + if (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + float operator()(uint8_t x) { + if (bits == 8) { + return float(*(thread fp8_e4m3*)(&x)); + } else { + return float(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +[[kernel]] void fp_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device uint8_t* scales [[buffer(2)]], + uint2 tidx [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + size_t index = tidx.x + grid_dim.x * size_t(tidx.y); + + float scale; + float w_thread = w[index]; + if (use_mx_scale) { + scale = simd_max(abs(w_thread)); + } else { + float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); + float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); + scale = tidx.x < 16 ? w_max_l : w_max_r; + } + scale /= bits == 4 ? 6.0f : 448.0f; + + using ScaleType = metal::conditional_t; + auto s = ScaleType(scale); + uint8_t q_scale = s.bits; + scale = float(s); + + // Write out the scales and biases + size_t gindex = index / group_size; + if (index % group_size == 0) { + scales[gindex] = q_scale; + } + + uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + if (bits == 4) { + uint8_t sval = simd_shuffle_down(output, 1); + output |= sval << bits; + } + constexpr int pack_factor = bits == 8 ? 1 : 2; + if (index % pack_factor == 0) { + out[index / pack_factor] = output; + } +} + +template +[[kernel]] void fp_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + constexpr int pack_factor = bits == 8 ? 1 : 2; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * pack_factor; + size_t gindex = oindex / group_size; + + out += oindex; + + using ScaleType = metal::conditional_t; + auto q_scale = ((device ScaleType*)(scales))[gindex]; + auto scale = float(q_scale); + + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} diff --git a/mlx/backend/metal/kernels/fp_quantized.metal b/mlx/backend/metal/kernels/fp_quantized.metal new file mode 100644 index 000000000..091174d91 --- /dev/null +++ b/mlx/backend/metal/kernels/fp_quantized.metal @@ -0,0 +1,147 @@ +// Copyright © 2025 Apple Inc. 
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/fp_quantized.h" + +#define instantiate_quantized(mode, name, type) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4", \ + fp_ ## name, \ + type, \ + 32, \ + 4) + +#define instantiate_quantized_batched(mode, name, type, batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + batched) + +#define instantiate_quantized_aligned(mode, name, type, aligned) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + aligned) + +#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + aligned, \ + batched) + +#define instantiate_quantized_quad(mode, name, type, D, batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + D, \ + batched) + +#define instantiate_quantized_split_k(mode, name, type, split_k) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + split_k) + +#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + 32, \ + 4, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + +#define instantiate_quantized_batched_wrap(mode, name, type) \ + instantiate_quantized_batched(mode, name, type, 1) \ + instantiate_quantized_batched(mode, name, type, 0) + +#define instantiate_quantized_all_batched(type) \ + instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \ + instantiate_quantized_batched_wrap(mxfp4, qmv, type) \ + instantiate_quantized_batched_wrap(mxfp4, qvm, type) \ + instantiate_quantized_batched_wrap(mxfp4, qmm_n, type) + +#define instantiate_quantized_all_single(type) \ + instantiate_quantized(mxfp4, gather_qmv_fast, type) \ + instantiate_quantized(mxfp4, gather_qmv, type) \ + instantiate_quantized(mxfp4, gather_qvm, type) \ + instantiate_quantized(mxfp4, gather_qmm_n, type) + +#define instantiate_quantized_all_aligned(type) \ + instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \ + instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0) + +#define instantiate_quantized_all_quad(type) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0) + +#define instantiate_quantized_all_splitk(type) \ + instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \ + instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32) + +#define instantiate_quantized_all_rhs(type) \ + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, 
mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false) + +#define instantiate_quantize_dequantize(type, mode, group_size, bits) \ + instantiate_kernel( \ + #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \ + fp_quantize, \ + type, \ + group_size, \ + bits) \ + instantiate_kernel( \ + #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ + fp_dequantize, \ + type, \ + group_size, \ + bits) + +#define instantiate_quantize_dequantize_modes(type) \ + instantiate_quantize_dequantize(type, mxfp4, 32, 4) \ + instantiate_quantize_dequantize(type, nvfp4, 16, 4) \ + instantiate_quantize_dequantize(type, mxfp8, 32, 8) + +#define instantiate_quantized_types(type) \ + instantiate_quantized_all_batched(type) \ + instantiate_quantized_all_quad(type) \ + instantiate_quantized_all_splitk(type) \ + instantiate_quantized_all_single(type) \ + instantiate_quantized_all_aligned(type) \ + instantiate_quantized_all_rhs(type) \ + instantiate_quantize_dequantize_modes(type) + +instantiate_quantized_types(float) +instantiate_quantized_types(bfloat16_t) +instantiate_quantized_types(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 423c07f66..327bb5a94 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -8,6 +8,7 @@ #include "mlx/backend/metal/kernels/cexpf.h" #include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/expm1f.h" +#include "mlx/backend/metal/kernels/fp8.h" namespace { constant float inf = metal::numeric_limits::infinity(); @@ -439,63 +440,15 @@ complex64_t ArcTan::operator()(complex64_t x) { return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); }; -inline float fp32_from_bits(uint32_t bits) { - return *(reinterpret_cast(&bits)); -} -inline float fp32_to_bits(float x) { - return *(reinterpret_cast(&x)); -} - struct ToFP8 { template uint8_t operator()(T f) { - // From PyTorch - // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 - uint32_t fp8_max = 1087 << 20; - uint32_t denorm_mask = 141 << 23; - uint32_t f_bits = fp32_to_bits(static_cast(f)); - uint8_t result = 0u; - uint32_t sign = f_bits & 0x80000000; - f_bits ^= sign; - if (f_bits >= fp8_max) { - // Default behavior saturates to min/max - result = 0x7E; - } else { - if (f_bits < (121 << 23)) { - f_bits = - fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); - result = static_cast(f_bits - denorm_mask); - } else { - // resulting mantissa is odd - uint8_t mant_odd = (f_bits >> 20) & 1; - f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; - f_bits += mant_odd; - result = static_cast(f_bits >> 20); - } - } - result |= static_cast(sign >> 24); - return result; + return fp8_e4m3(f).bits; } }; struct FromFP8 { float operator()(uint8_t x) { - // From PyTorch: - // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46 - uint32_t w = static_cast(x) << 24; - uint32_t sign = w & 0x80000000; - uint32_t nonsign = w & 0x7FFFFFFF; - - uint32_t renorm_shift = metal::clz(nonsign); - renorm_shift = renorm_shift > 4 ? 
renorm_shift - 4 : 0; - - int32_t inf_nan_mask = - (static_cast(nonsign + 0x01000000) >> 8) & 0x7F800000; - int32_t zero_mask = static_cast(nonsign - 1) >> 31; - uint32_t result = sign | - ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | - inf_nan_mask) & - ~zero_mask); - return fp32_from_bits(result); + return float(*(thread fp8_e4m3*)(&x)); } }; diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 328669a92..e03e5dca2 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped( int bits, Args... args) { std::string template_def; - auto fname = mode + "_" + func; - if (mode == "affine") { - template_def = get_template_definition( - name, fname, type, group_size, bits, std::forward(args)...); - } else { - template_def = get_template_definition( - name, fname, type, group_size, "uint8_t", std::forward(args)...); - } + std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func; + template_def = get_template_definition( + name, fname, type, group_size, bits, std::forward(args)...); return get_quantized_kernel(d, name, template_def, mode); } @@ -1045,26 +1040,31 @@ void fast::Quantize::eval_gpu( compute_encoder.set_input_array(w, 0); if (dequantize_) { auto scales = ensure_row_contiguous(inputs[1], d, s); - auto biases = ensure_row_contiguous(inputs[2], d, s); compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); compute_encoder.set_output_array(out, 3); + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous(inputs[2], d, s); + compute_encoder.set_input_array(biases, 2); + } } else { auto& scales = outputs[1]; - auto& biases = outputs[2]; scales.set_data(allocator::malloc(scales.nbytes())); - biases.set_data(allocator::malloc(biases.nbytes())); compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(scales, 2); - compute_encoder.set_output_array(biases, 3); + if (mode_ == QuantizationMode::Affine) { + auto& biases = outputs[2]; + biases.set_data(allocator::malloc(biases.nbytes())); + compute_encoder.set_output_array(biases, 3); + } } auto type_string = dequantize_ ? get_type_string(out.dtype()) : get_type_string(w_pre.dtype()); + auto mode = quantization_mode_to_string(mode_); std::string kname; concatenate( kname, - dequantize_ ? "affine_dequantize" : "affine_quantize", + mode + (dequantize_ ? "_dequantize" : "_quantize"), "_", type_string, "_gs_", @@ -1075,7 +1075,7 @@ void fast::Quantize::eval_gpu( d, kname, dequantize_ ? "dequantize" : "quantize", - "affine", + mode, type_string, group_size_, bits_); @@ -1088,7 +1088,8 @@ void fast::Quantize::eval_gpu( int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 : bits_ == 6 ? 4 : 8 / bits_; - int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size; + int per_thread = + dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1); size_t nthreads = dequantize_ ? 
out.size() / packs_per_int : w.size() / per_thread; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index d2b4b5611..649e554e6 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -273,7 +273,7 @@ class ConvertFP8 : public Primitive { }; bool is_equivalent(const Primitive& other) const override; - DEFINE_INPUT_OUTPUT_SHAPE() + DEFINE_INPUT_OUTPUT_SHAPE(); private: bool to_fp8_; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 879ef4fd5..271462d56 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2,7 +2,6 @@ // Required for using M_PI in MSVC. #define _USE_MATH_DEFINES - #include #include #include @@ -4017,21 +4016,50 @@ array conv_general( {in, wt}); } -void validate_mode(std::string_view tag, const std::string& mode) { - if (mode != "affine" && mode != "mxfp4") { - std::ostringstream msg; - msg << "[" << tag << "] Invalid quantization mode '" << mode << "'."; - throw std::invalid_argument(msg.str()); +std::pair quantization_params_from_mode( + QuantizationMode mode, + std::optional group_size_, + std::optional bits_) { + int default_group_size; + int default_bits; + switch (mode) { + case QuantizationMode::Affine: + default_group_size = 64; + default_bits = 4; + break; + case QuantizationMode::Nvfp4: + default_group_size = 16; + default_bits = 4; + break; + case QuantizationMode::Mxfp4: + default_group_size = 32; + default_bits = 4; + break; + case QuantizationMode::Mxfp8: + default_group_size = 32; + default_bits = 8; + break; } + return { + group_size_.has_value() ? *group_size_ : default_group_size, + bits_.has_value() ? *bits_ : default_bits}; } -Dtype validate_mode_with_type( +std::pair validate_mode_with_type( std::string_view tag, const array& scales, const std::optional& biases, + const std::optional out_type, const std::string& mode) { - validate_mode(tag, mode); - if (mode == "affine") { + auto qmode = string_to_quantization_mode(mode, tag); + if (out_type.has_value() && !issubdtype(*out_type, floating)) { + std::ostringstream msg; + msg << "[" << tag << "] Only real floating types are supported but " + << "output dtype == " << *out_type << "."; + throw std::invalid_argument(msg.str()); + } + + if (qmode == QuantizationMode::Affine) { if (!biases) { std::ostringstream msg; msg << "[" << tag << "] Biases must be provided for affine quantization."; @@ -4045,7 +4073,11 @@ Dtype validate_mode_with_type( << " and biases.dtype() == " << biases->dtype() << "."; throw std::invalid_argument(msg.str()); } - return dtype; + if (out_type.has_value()) { + return {*out_type, qmode}; + } else { + return {dtype, qmode}; + } } if (biases) { std::ostringstream msg; @@ -4053,7 +4085,11 @@ Dtype validate_mode_with_type( << "'."; throw std::invalid_argument(msg.str()); } - return bfloat16; + if (out_type.has_value()) { + return {*out_type, qmode}; + } else { + return {bfloat16, qmode}; + } } array quantized_matmul( @@ -4062,17 +4098,24 @@ array quantized_matmul( array scales, std::optional biases /* = std::nullopt */, bool transpose /* = true */, - int group_size /* = 64 */, - int bits /* = 4 */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { + auto [dtype, qmode] = validate_mode_with_type( + "quantized_matmul", scales, biases, std::nullopt, mode); + + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( 
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits); - auto dtype = - validate_mode_with_type("quantized_matmul", scales, biases, mode); - dtype = promote_types(x.dtype(), dtype); + if (qmode == QuantizationMode::Affine) { + dtype = promote_types(x.dtype(), dtype); + } else { + dtype = x.dtype(); + } if (!issubdtype(dtype, floating)) { std::ostringstream msg; @@ -4081,7 +4124,7 @@ array quantized_matmul( throw std::invalid_argument(msg.str()); } std::vector inputs; - if (mode == "affine") { + if (qmode == QuantizationMode::Affine) { inputs = { astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; } else { @@ -4098,11 +4141,7 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared( - to_stream(s), - group_size, - bits, - string_to_quantization_mode(mode), - transpose), + to_stream(s), group_size, bits, qmode, transpose), std::move(inputs)); } @@ -4216,13 +4255,110 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { {w}); } +std::vector fp_quantize( + const array& w, + int group_size, + int bits, + QuantizationMode mode, + Stream s) { + int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; + int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; + if (group_size != expected_gs) { + std::ostringstream msg; + msg << "[quantize] " << quantization_mode_to_string(mode) + << " quantization requires group size " << expected_gs << " but got " + << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != expected_bits) { + std::ostringstream msg; + msg << "[quantize] " << quantization_mode_to_string(mode) + << " quantization requires bits to be " << expected_bits << " but got " + << bits << "."; + throw std::invalid_argument(msg.str()); + } + auto fallback = [bits = bits, group_size = group_size, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + float maxval = (bits == 4) ? 
6.0f : 448.0f; + auto new_shape = w.shape(); + new_shape.back() = -1; + auto wq = reshape(w, {-1, group_size}, s); + auto scales = + divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s); + if (group_size == 16) { + // convert to e4m3 + scales = to_fp8(scales, s); + wq = divide(wq, from_fp8(scales, w.dtype(), s), s); + } else { + // convert to e8m0 + auto z = array(0, scales.dtype()); + scales = where( + equal(scales, z, s), + z, + astype(round(log2(scales, s), s), int32, s), + s); + + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + } + if (bits == 4) { + auto lut = array({ + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }); + lut = astype(lut, w.dtype(), s); + wq = argmin( + abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); + auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); + wq = reshape(wq, {-1, 4, 8}, s); + wq = sum(multiply(wq, shifts, s), -1, false, s); + } else { + wq = view(to_fp8(wq, s), uint32, s); + } + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales)}; + }; + + if (s.device == Device::gpu) { + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) * bits / 32; + auto sshape = w.shape(); + sshape.back() = w.shape(-1) / group_size; + return array::make_arrays( + {std::move(wq_shape), std::move(sshape)}, + {uint32, uint8}, + std::make_shared( + s, fallback, group_size, bits, mode, false), + {w}); + } + return fallback({w}); +} + std::vector quantize( const array& w, - int group_size /* = 64 */, - int bits /* = 4 */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { - validate_mode("quantize", mode); + auto qmode = string_to_quantization_mode(mode, "quantize"); + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); if (!issubdtype(w.dtype(), floating)) { std::ostringstream msg; msg << "[quantize] Only real floating types can be quantized " @@ -4246,57 +4382,10 @@ std::vector quantize( throw std::invalid_argument(msg.str()); } - if (mode == "affine") { + if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); } else { - if (group_size != 32) { - std::ostringstream msg; - msg << "[quantize] mxfp4 quantization requires group size 32 " - << "but got " << group_size << "."; - throw std::invalid_argument(msg.str()); - } - if (bits != 4) { - std::ostringstream msg; - msg << "[quantize] mxfp4 quantization requires bits to be 4 " - << "but got " << bits << "."; - throw std::invalid_argument(msg.str()); - } - - auto lut = array({ - +0.0f, - +0.5f, - +1.0f, - +1.5f, - +2.0f, - +3.0f, - +4.0f, - +6.0f, - -0.0f, - -0.5f, - -1.0f, - -1.5f, - -2.0f, - -3.0f, - -4.0f, - -6.0f, - }); - lut = astype(lut, w.dtype(), s); - - auto new_shape = w.shape(); - new_shape.back() = -1; - auto wq = reshape(w, {-1, group_size}, s); - auto scales = - divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s); - scales = astype(log2(scales, s), int32, s); - wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); - scales = astype(add(scales, array(127, int32), s), uint8, s); - wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); - auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); - wq = 
reshape(wq, {-1, group_size / 8, 8}, s); - wq = sum(multiply(wq, shifts, s), -1, false, s); - wq = reshape(wq, new_shape, s); - scales = reshape(scales, new_shape, s); - return {std::move(wq), std::move(scales)}; + return fp_quantize(w, group_size, bits, qmode, to_stream(s)); } } @@ -4307,16 +4396,13 @@ array affine_dequantize( int group_size, int bits, StreamOrDevice s_) { - if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - auto wshape = w.shape(); auto sshape = scales.shape(); auto bshape = biases.shape(); + if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) { + throw std::invalid_argument( + "[dequantize] Shape of scales and biases does not match the matrix"); + } wshape.back() = -1; sshape.back() = -1; bshape.back() = -1; @@ -4397,15 +4483,132 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } +array fp_dequantize( + const array& w, + const array& scales, + int group_size, + int bits, + Dtype out_type, + QuantizationMode mode, + Stream s) { + int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; + int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; + if (group_size != expected_gs) { + std::ostringstream msg; + msg << "[dequantize] " << quantization_mode_to_string(mode) + << " quantization requires group size " << expected_gs << " but got " + << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != expected_bits) { + std::ostringstream msg; + msg << "[dequantize] " << quantization_mode_to_string(mode) + << " quantization requires bits to be " << expected_bits << " but got " + << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + if (wshape.size() != sshape.size()) { + throw std::invalid_argument( + "[dequantize] Shape of scales does not match the matrix"); + } + + wshape.back() = -1; + sshape.back() = -1; + + if (wshape != sshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales does not match the matrix"); + } + + // Packing into uint32 + int out_size = w.shape(-1) * 32 / bits; + + if (out_size != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales does not match the matrix " + << "given the quantization parameters. 
Provided matrix of shape " + << w.shape() << " and scales of shape " << scales.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto fallback = + [wshape = std::move(wshape), + sshape = std::move(sshape), + group_size, + bits, + out_type, + s](const std::vector& inputs) mutable -> std::vector { + auto out = inputs[0]; + auto scales = inputs[1]; + if (bits == 4) { + auto lut = array( + { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }, + out_type); + out = view(reshape(out, {-1, 4}, s), int8, s); + auto idx_lo = bitwise_and(out, array(0x0F, int8), s); + auto idx_hi = right_shift(out, array(4, int8), s); + auto lo = gather(lut, idx_lo, 0, {1}, s); + auto hi = gather(lut, idx_hi, 0, {1}, s); + out = concatenate({lo, hi}, -1, s); + } else { + out = from_fp8(view(out, uint8, s), out_type, s); + } + out = reshape(out, {-1, group_size}, s); + scales = reshape(scales, {-1, 1}, s); + if (group_size == 16) { + scales = from_fp8(scales, out_type, s); + } else { + scales = subtract(astype(scales, out_type, s), array(127, out_type), s); + scales = power(array(2.0f, out_type), scales, s); + } + return {reshape(multiply(out, scales, s), wshape, s)}; + }; + if (s.device == Device::gpu) { + auto out_shape = w.shape(); + out_shape.back() = out_size; + return array( + std::move(out_shape), + out_type, + std::make_shared( + s, fallback, group_size, bits, mode, true), + {w, scales}); + } + return fallback({w, scales})[0]; +} + array dequantize( const array& w, const array& scales, const std::optional& biases /* = std::nullopt */, - int group_size /* = 64 */, - int bits /* = 4 */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, + std::optional dtype /* = std::nullopt */, StreamOrDevice s /* = {} */) { - validate_mode_with_type("dequantize", scales, biases, mode); + auto [out_type, qmode] = + validate_mode_with_type("dequantize", scales, biases, dtype, mode); + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); if (bits <= 0) { std::ostringstream msg; msg << "[dequantize] Invalid value for bits: " << bits; @@ -4420,89 +4623,21 @@ array dequantize( throw std::invalid_argument( "[dequantize] The matrix should be given as a uint32"); } + if (w.ndim() < 2) { + std::ostringstream msg; + msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } - if (mode == "affine") { - return affine_dequantize(w, scales, *biases, group_size, bits, s); + if (qmode == QuantizationMode::Affine) { + return astype( + affine_dequantize(w, scales, *biases, group_size, bits, s), + out_type, + s); } else { - if (group_size != 32) { - std::ostringstream msg; - msg << "[dequantize] mxfp4 quantization requires group size 32 " - << "but got " << group_size << "."; - throw std::invalid_argument(msg.str()); - } - if (bits != 4) { - std::ostringstream msg; - msg << "[dequantize] mxfp4 quantization requires bits to be 4 " - << "but got " << bits << "."; - throw std::invalid_argument(msg.str()); - } - - if (w.ndim() < 2 || scales.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - auto wshape = w.shape(); - auto sshape = scales.shape(); - 
wshape.back() = -1; - sshape.back() = -1; - - if (wshape != sshape) { - throw std::invalid_argument( - "[dequantize] Shape of scales does not match the matrix"); - } - - if (w.dtype() != uint32) { - throw std::invalid_argument( - "[dequantize] The matrix should be given as a uint32"); - } - - // Packing into uint32 - int out_size = w.shape(-1) * 32 / bits; - - if (out_size != scales.shape(-1) * group_size) { - std::ostringstream msg; - msg << "[dequantize] Shape of scales does not match the matrix " - << "given the quantization parameters. Provided matrix of shape " - << w.shape() << " and scales of shape " << scales.shape() << "."; - throw std::invalid_argument(msg.str()); - } - - auto dtype = bfloat16; - auto lut = array( - { - +0.0f, - +0.5f, - +1.0f, - +1.5f, - +2.0f, - +3.0f, - +4.0f, - +6.0f, - -0.0f, - -0.5f, - -1.0f, - -1.5f, - -2.0f, - -3.0f, - -4.0f, - -6.0f, - }, - dtype); - - auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s); - - auto idx_lo = bitwise_and(what, array(0x0F, int8), s); - auto idx_hi = right_shift(what, array(4, int8), s); - auto lo = gather(lut, idx_lo, 0, {1}, s); - auto hi = gather(lut, idx_hi, 0, {1}, s); - what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s); - auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s); - exponent = reshape(exponent, {-1, 1}, s); - return reshape( - multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s); + return fp_dequantize( + w, scales, group_size, bits, out_type, qmode, to_stream(s)); } } @@ -4548,21 +4683,27 @@ array gather_qmm( std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, bool transpose /* = true */, - int group_size /* = 64 */, - int bits /* = 4 */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size, bits, mode, s); + x, w, scales, biases, transpose, group_size_, bits_, mode, s); } + auto [out_type, qmode] = + validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode); + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "gather_qmm", x, w, scales, biases, transpose, group_size, bits); - - auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode); - out_type = promote_types(x.dtype(), out_type); + if (qmode == QuantizationMode::Affine) { + out_type = promote_types(x.dtype(), out_type); + } else { + out_type = x.dtype(); + } if (!issubdtype(out_type, floating)) { std::ostringstream msg; @@ -4601,7 +4742,7 @@ array gather_qmm( out_shape.push_back(x.shape(-2)); out_shape.push_back(w_outer_dims); std::vector inputs; - if (mode == "affine") { + if (qmode == QuantizationMode::Affine) { inputs = { astype(x, out_type, s), std::move(w), @@ -4624,7 +4765,7 @@ array gather_qmm( to_stream(s), group_size, bits, - string_to_quantization_mode(mode), + qmode, transpose, sorted_indices && !rhs_indices_, sorted_indices && !lhs_indices_), diff --git a/mlx/ops.h b/mlx/ops.h index 312caac6d..49c64e74f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1379,16 +1379,16 @@ array quantized_matmul( array scales, std::optional biases = std::nullopt, bool transpose = true, - int group_size = 64, - int bits = 4, + std::optional group_size = std::nullopt, + std::optional 
bits = std::nullopt, const std::string& mode = "affine", StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ std::vector quantize( const array& w, - int group_size = 64, - int bits = 4, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, const std::string& mode = "affine", StreamOrDevice s = {}); @@ -1397,9 +1397,10 @@ array dequantize( const array& w, const array& scales, const std::optional& biases = std::nullopt, - int group_size = 64, - int bits = 4, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, const std::string& mode = "affine", + std::optional dtype = std::nullopt, StreamOrDevice s = {}); /** Convert an E4M3 float8 to the given floating point dtype. */ @@ -1417,8 +1418,8 @@ array gather_qmm( std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, bool transpose = true, - int group_size = 64, - int bits = 4, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, const std::string& mode = "affine", bool sorted_indices = false, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0b335e765..976200ba3 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3328,19 +3328,37 @@ std::pair, std::vector> Power::vmap( } std::string quantization_mode_to_string(QuantizationMode mode) { - if (mode == QuantizationMode::Affine) { - return "affine"; - } else { - return "mxfp4"; + switch (mode) { + case QuantizationMode::Affine: + return "affine"; + case QuantizationMode::Mxfp4: + return "mxfp4"; + case QuantizationMode::Mxfp8: + return "mxfp8"; + case QuantizationMode::Nvfp4: + default: + return "nvfp4"; } } -QuantizationMode string_to_quantization_mode(const std::string& mode) { +QuantizationMode string_to_quantization_mode( + const std::string& mode, + std::string_view tag /* = "" */) { if (mode == "affine") { return QuantizationMode::Affine; - } else { + } else if (mode == "mxfp4") { return QuantizationMode::Mxfp4; + } else if (mode == "mxfp8") { + return QuantizationMode::Mxfp8; + } else if (mode == "nvfp4") { + return QuantizationMode::Nvfp4; } + std::string msg; + if (!tag.empty()) { + msg += "[" + std::string(tag) + "]"; + } + msg += " Invalid quantization mode '" + mode + "'."; + throw std::invalid_argument(msg); } std::pair, std::vector> QuantizedMatmul::vmap( @@ -3404,6 +3422,7 @@ std::vector QuantizedMatmul::vjp( group_size_, bits_, quantization_mode_to_string(mode_), + std::nullopt, stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); @@ -3558,6 +3577,7 @@ std::vector GatherQMM::vjp( group_size_, bits_, quantization_mode_to_string(mode_), + std::nullopt, stream()), -1, {-1, group_size_}, diff --git a/mlx/primitives.h b/mlx/primitives.h index 2a843a0e4..a1ad2425c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; -enum class QuantizationMode { Affine, Mxfp4 }; +enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; std::string quantization_mode_to_string(QuantizationMode mode); -QuantizationMode string_to_quantization_mode(const std::string& mode); +QuantizationMode string_to_quantization_mode( + const std::string& mode, + std::string_view error_tag = ""); class Abs : public UnaryPrimitive { public: diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2e364db76..9816837ba 100644 --- a/python/src/ops.cpp 
+++ b/python/src/ops.cpp @@ -4194,13 +4194,13 @@ void init_ops(nb::module_& m) { "scales"_a, "biases"_a = nb::none(), "transpose"_a = true, - "group_size"_a = 64, - "bits"_a = 4, + "group_size"_a = nb::none(), + "bits"_a = nb::none(), "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4216,10 +4216,12 @@ void init_ops(nb::module_& m) { transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing ``x @ w.T`` or ``x @ w``. Default: ``True``. - group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. + group_size (int, optional): The size of the group in ``w`` that shares a + scale and bias. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): The number of bits occupied by each element of + ``w`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: @@ -4229,35 +4231,36 @@ void init_ops(nb::module_& m) { "quantize", &mx::quantize, nb::arg(), - "group_size"_a = 64, - "bits"_a = 4, + "group_size"_a = nb::none(), + "bits"_a = nb::none(), "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( - Quantize the matrix ``w`` using ``bits`` bits per element. + Quantize the array ``w``. Note, every ``group_size`` elements in a row of ``w`` are quantized - together. Hence, number of columns of ``w`` should be divisible by - ``group_size``. In particular, the rows of ``w`` are divided into groups of - size ``group_size`` which are quantized together. + together. Hence, the last dimension of ``w`` should be divisible by + ``group_size``. .. warning:: - ``quantize`` currently only supports 2D inputs with the second - dimension divisible by ``group_size`` + ``quantize`` only supports inputs with two or more dimensions with + the last dimension divisible by ``group_size`` - The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They - are described in more detail below. + The supported quantization modes are ``"affine"``, ``"mxfp4"``, + ``"mxfp8"``, and ``"nvfp4"``. They are described in more detail below. 
Args: - w (array): Matrix to be quantized + w (array): Array to be quantized group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. Default: ``64``. + scale and bias. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. bits (int, optional): The number of bits occupied by each element of - ``w`` in the returned quantized matrix. Default: ``4``. + ``w`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: @@ -4268,7 +4271,22 @@ void init_ops(nb::module_& m) { * biases (array): The quantization biases (returned for ``mode=="affine"``). Notes: - The ``affine`` mode quantizes groups of :math:`g` consecutive + .. _quantize-modes: + + .. table:: Quantization modes + + ====== ====================== ========================== ============= ===== + mode group size bits scale type bias + ====== ====================== ========================== ============= ===== + affine 32, 64\ :sup:`*`, 128 2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes + mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no + mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no + nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no + ====== ====================== ========================== ============= ===== + + :sup:`*` indicates the default value when unspecified. + + The ``"affine"`` mode quantizes groups of :math:`g` consecutive elements in a row of ``w``. For each group the quantized representation of each element :math:`\hat{w_i}` is computed as follows: @@ -4291,11 +4309,17 @@ void init_ops(nb::module_& m) { :math:`\beta` which are the returned ``scales`` and ``biases`` respectively. - The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements - of ``w``. For ``mxfp4`` the group size must be ``32``. The elements - are quantized to 4-bit precision floating-point values (E2M1) with a - shared 8-bit scale per group. Unlike ``affine`` quantization, - ``mxfp4`` does not have a bias value. More details on the format can + The ``"mxfp4"``, ``"mxfp8"``, and ``"nvfp4"`` modes similarly + quantize groups of :math:`g` elements of ``w``. For the ``"mx"`` + modes, the group size must be ``32``. For ``"nvfp4"`` the group + size must be ``16``. The elements are quantized to 4-bit or 8-bit + precision floating-point values: E2M1 for ``"fp4"`` and E4M3 for + ``"fp8"``. There is a shared 8-bit scale per group. The ``"mx"`` + modes use an E8M0 scale and the ``"nv"`` mode uses an E4M3 scale. + Unlike ``affine`` quantization, these modes do not have a bias + value. + + More details on the ``"mx"`` formats can + be found in the `specification `_. )pbdoc"); m.def( @@ -4304,13 +4328,14 @@ void init_ops(nb::module_& m) { "dequantize", &mx::dequantize, nb::arg(), "scales"_a, "biases"_a = nb::none(), - "group_size"_a = 64, - "bits"_a = 4, + "group_size"_a = nb::none(), + "bits"_a = nb::none(), "mode"_a = "affine", + "dtype"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dequantize the matrix ``w`` using quantization parameters. 
@@ -4320,16 +4345,23 @@ void init_ops(nb::module_& m) { biases (array, optional): The biases to use per ``group_size`` elements of ``w``. Default: ``None``. group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. + scale and bias. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): The number of bits occupied by each element of + ``w`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + dtype (Dtype, optional): The data type of the dequantized output. If + ``None`` the return type is inferred from the scales and biases + when possible and otherwise defaults to ``bfloat16``. + Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: array: The dequantized version of ``w`` Notes: - The currently supported quantization modes are ``"affine"`` and ``mxfp4``. + The currently supported quantization modes are ``"affine"``, + ``"mxfp4"``, ``"mxfp8"``, and ``"nvfp4"``. For ``affine`` quantization, given the notation in :func:`quantize`, we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` @@ -4349,14 +4381,14 @@ void init_ops(nb::module_& m) { "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), "transpose"_a = true, - "group_size"_a = 64, - "bits"_a = 4, + "group_size"_a = nb::none(), + "bits"_a = nb::none(), "mode"_a = "affine", nb::kw_only(), "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4379,10 +4411,12 @@ void init_ops(nb::module_& m) { transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing ``x @ w.T`` or ``x @ w``. Default: ``True``. - group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. + group_size (int, optional): The size of the group in ``w`` that shares a + scale and bias. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): The number of bits occupied by each element of + ``w`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. sorted_indices (bool, optional): May allow a faster implementation if the passed indices are sorted. Default: ``False``. 
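For reference, a minimal usage sketch of the modes documented above, using only the Python calls exercised by the bindings and tests in this patch; the shapes and the dtype=mx.float32 argument are illustrative assumptions, not part of the change.

import mlx.core as mx

# Weights with a last dimension divisible by every supported group size.
w = mx.random.normal((256, 512)).astype(mx.bfloat16)

# mxfp8: group_size and bits default to 32 and 8 when left unspecified.
w_q, scales = mx.quantize(w, mode="mxfp8")
w_hat = mx.dequantize(w_q, scales, mode="mxfp8")

# nvfp4: 16-element groups with an e4m3 scale and no bias; dtype selects the
# output precision of the dequantized array (float32 assumed here).
w_q, scales = mx.quantize(w, mode="nvfp4")
w_hat = mx.dequantize(w_q, scales, mode="nvfp4", dtype=mx.float32)

# mxfp4 matmul: the output dtype follows x (see test_qmm_mxfp4_type below).
x = mx.random.normal((4, 512)).astype(mx.float16)
w_q, scales = mx.quantize(w, mode="mxfp4")
out = mx.quantized_matmul(x, w_q, scales, mode="mxfp4")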
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 3a195ef54..e75106303 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -55,26 +55,109 @@ class TestQuantized(mlx_tests.MLXTestCase): # Invalid bits / group size with self.assertRaises(ValueError): - mx.quantize(w, bits=3, group_size=32, mode="mxfp4") + mx.quantize(w, bits=3, mode="mxfp4") with self.assertRaises(ValueError): - mx.quantize(w, group_size=64, bits=4, mode="mxfp4") + mx.quantize(w, group_size=64, mode="mxfp4") - w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4") + w_q, scales = mx.quantize(w, mode="mxfp4") + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, bits=3, mode="mxfp4") with self.assertRaises(ValueError): - mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4") + mx.dequantize(w_q, scales, group_size=64, mode="mxfp4") + # Invalid output type with self.assertRaises(ValueError): - mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4") + mx.dequantize( + w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32 + ) - w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, mode="mxfp4") self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) # test quantize/dequantize 0s a = mx.zeros((256, 512)) - w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4") - w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + w_q, scales = mx.quantize(a, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, mode="mxfp4") + self.assertTrue(mx.all(w_hat == 0)) + + def test_mxfp8_quantize_dequantize(self): + w = 2 * mx.random.uniform(shape=(512, 32)) - 1 + w = w.astype(mx.bfloat16) + + # Invalid bits / group size + with self.assertRaises(ValueError): + mx.quantize(w, bits=3, mode="mxfp8") + + with self.assertRaises(ValueError): + mx.quantize(w, group_size=32, bits=7, mode="mxfp8") + w_q, scales = mx.quantize(w, group_size=32, mode="mxfp8") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, group_size=16, mode="mxfp8") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, bits=4, mode="mxfp8") + + w_hat = mx.dequantize(w_q, scales, mode="mxfp8") + + self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1)) + + # test quantize/dequantize 0s + a = mx.zeros((256, 512)) + w_q, scales = mx.quantize(a, mode="mxfp8") + w_hat = mx.dequantize(w_q, scales, mode="mxfp8") + self.assertTrue(mx.all(w_hat == 0)) + + def test_nvfp4_quantize_dequantize(self): + lut = mx.array( + [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + ) + w = lut[mx.random.randint(0, 16, shape=(128, 512))] + w = w.reshape(-1, 16) + w[:, 0] = 6 + w = (w + 3e-6).astype(mx.bfloat16) + + # Invalid bits / group size + with self.assertRaises(ValueError): + mx.quantize(w, bits=3, mode="nvfp4") + + with self.assertRaises(ValueError): + mx.quantize(w, group_size=64, mode="nvfp4") + + w_q, scales = mx.quantize(w, mode="nvfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, bits=3, mode="nvfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, group_size=32, mode="nvfp4") + + w_hat = mx.dequantize(w_q, scales, mode="nvfp4") + self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + + # test quantize/dequantize 0s + a = mx.zeros((256, 512)) + w_q, scales = mx.quantize(a, mode="nvfp4") + w_hat = mx.dequantize(w_q, scales, 
mode="nvfp4") self.assertTrue(mx.all(w_hat == 0)) def test_qmm(self): @@ -662,6 +745,25 @@ class TestQuantized(mlx_tests.MLXTestCase): test_shape(32, 512, 32, transpose=False, **kwargs) test_shape(1, 512, 32, transpose=False, **kwargs) + def test_qmm_mxfp4_type(self): + indices = mx.array([[2], [0], [1]], dtype=mx.uint32) + + for t in [mx.bfloat16, mx.float16, mx.float32]: + x = mx.random.normal((32, 256)).astype(t) + + w = mx.random.normal((32, 256)) + wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32) + out = mx.quantized_matmul(x, wq, s, mode="mxfp4", group_size=32, bits=4) + self.assertEqual(out.dtype, t) + + w = mx.random.normal((4, 32, 256)) + wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32) + + out = mx.gather_qmm( + x, wq, s, rhs_indices=indices, mode="mxfp4", group_size=32, bits=4 + ) + self.assertEqual(out.dtype, t) + def test_gather_matmul_grad(self): def quantize(w, transpose=True, group_size=64, bits=4): qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)