From c961a3a5576c8cf0ba03349547ba6e4117ebf046 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 21 Oct 2025 11:49:58 -0700
Subject: [PATCH] fast cuda kernel for mx/nv quantization

---
 mlx/backend/cuda/CMakeLists.txt               |   1 +
 mlx/backend/cuda/quantized/affine_quantize.cu |   2 +-
 mlx/backend/cuda/quantized/fp_quantize.cu     | 218 +++++++++++
 mlx/backend/cuda/quantized/quantized.cpp      |  19 +-
 mlx/backend/cuda/quantized/quantized.h        |  18 +
 mlx/ops.cpp                                   | 352 ++++++++++--------
 mlx/primitives.cpp                            |  30 +-
 mlx/primitives.h                              |   6 +-
 python/tests/test_quantized.py                |   7 +-
 9 files changed, 492 insertions(+), 161 deletions(-)
 create mode 100644 mlx/backend/cuda/quantized/fp_quantize.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index eabee94f2..19cafb932 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
diff --git a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu
index 94e67d135..a64597a88 100644
--- a/mlx/backend/cuda/quantized/affine_quantize.cu
+++ b/mlx/backend/cuda/quantized/affine_quantize.cu
@@ -306,7 +306,7 @@ void affine_dequantize(
   enc.set_input_array(scales);
   enc.set_input_array(biases);
   enc.set_output_array(w);
-  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
+  dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu
new file mode 100644
index 000000000..70b125522
--- /dev/null
+++ b/mlx/backend/cuda/quantized/fp_quantize.cu
@@ -0,0 +1,218 @@
+// Copyright © 2025 Apple Inc.
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/quantized.h" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +namespace mlx::core { +namespace cu { + +template +struct Quantize { + __device__ uint8_t operator()(float x) { + if constexpr (bits == 8) { + return __nv_fp8_e4m3(x).__x; + } else { + return __nv_fp4_e2m1(x).__x; + } + } +}; + +template +struct Dequantize { + __device__ float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(__nv_fp8_e4m3*)(&x)); + } else { + return float(*(__nv_fp4_e2m1*)(&x)); + } + } +}; + +namespace cg = cooperative_groups; + +template +__global__ void +fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + size_t out_index = tidx + grid_dim_x * size_t(tidy); + size_t in_index = out_index; + if (in_index >= size) { + return; + } + + float w_thread = w[in_index]; + + cg::greater max_op; + auto warp = cg::tiled_partition(cg::this_thread_block()); + + float scale = cg::reduce(warp, abs(w_thread), max_op); + scale /= bits == 4 ? 6.0f : 448.0f; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale); + uint8_t q_scale = s.__x; + scale = float(s); + + // Write out the scales + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = q_scale; + } + + uint8_t output = 0; + uint8_t val = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + output = val; + if (bits == 4) { + uint8_t sval = warp.shfl_down(val, 1); + output |= sval << bits; + } + constexpr int pack_factor = bits == 8 ? 1 : 2; + if (out_index % pack_factor == 0) { + out[out_index / pack_factor] = output; + } +} + +template +__global__ void +fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + + constexpr int pack_factor = bits == 8 ? 
1 : 2; + size_t offset = tidx + grid_dim_x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + using ScaleType = + std::conditional_t; + auto scale = float(((ScaleType*)(scales))[gindex]); + + out += oindex; + + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} + +} // namespace cu + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize; + if (bits == 8) { + kernel = cu::fp_quantize; + } else if (group_size == 16) { + kernel = cu::fp_quantize; + } + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + w.data(), + wq.data(), + scales.data(), + w.size()); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); +} + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s) { + constexpr int uint8_per_uint32 = 4; + int packs_per_int = 8 / bits; + + size_t size = w.size() / packs_per_int; + bool large = size > UINT_MAX; + auto grid_shape = w.shape(); + grid_shape.back() *= uint8_per_uint32; + + enc.set_input_array(wq); + enc.set_input_array(scales); + enc.set_output_array(w); + dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_dequantize; + if (bits == 8) { + kernel = cu::fp_dequantize; + } else if (group_size == 16) { + kernel = cu::fp_dequantize; + } + auto [num_blocks, block_dims] = + get_launch_args(size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + wq.data(), + scales.data(), + w.data(), + w.size()); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not dequantize to output with type float64."); + } + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 71c687d85..58710834f 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu( if (dequantize_) { auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); - auto biases = ensure_row_contiguous(inputs[2], enc, s); auto& w = outputs[0]; w.set_data(allocator::malloc(w.nbytes())); - affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous(inputs[2], enc, s); + affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); + } else { + fp_dequantize(wq, scales, w, group_size_, bits_, enc, s); + } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); auto& wq = outputs[0]; auto& scales = outputs[1]; - auto& biases = outputs[2]; 
     wq.set_data(allocator::malloc(wq.nbytes()));
     scales.set_data(allocator::malloc(scales.nbytes()));
-    biases.set_data(allocator::malloc(biases.nbytes()));
-
-    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    if (mode_ == QuantizationMode::Affine) {
+      auto& biases = outputs[2];
+      biases.set_data(allocator::malloc(biases.nbytes()));
+      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    } else {
+      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
+    }
   }
 }
diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h
index ec6a08000..4f1980a9c 100644
--- a/mlx/backend/cuda/quantized/quantized.h
+++ b/mlx/backend/cuda/quantized/quantized.h
@@ -24,4 +24,22 @@ void affine_dequantize(
     cu::CommandEncoder& enc,
     const Stream& s);
 
+void fp_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+void fp_dequantize(
+    const array& wq,
+    const array& scales,
+    array& w,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
 } // namespace mlx::core
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index d99bf2f46..f6b0ee04c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4017,22 +4017,22 @@ array conv_general(
       {in, wt});
 }
 
-void validate_mode(std::string_view tag, const std::string& mode) {
-  if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
-      mode != "nvfp4") {
-    std::ostringstream msg;
-    msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
-    throw std::invalid_argument(msg.str());
-  }
-}
-
-Dtype validate_mode_with_type(
+std::pair<Dtype, QuantizationMode> validate_mode_with_type(
     std::string_view tag,
     const array& scales,
     const std::optional<array>& biases,
+    const std::optional<Dtype> out_type,
     const std::string& mode) {
-  validate_mode(tag, mode);
-  if (mode == "affine") {
+  auto qmode = string_to_quantization_mode(mode, tag);
+  // TODO add tests for out_type
+  if (out_type.has_value() && !issubdtype(*out_type, floating)) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] Only real floating types are supported but "
+        << "output dtype == " << *out_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
     if (!biases) {
       std::ostringstream msg;
       msg << "[" << tag << "] Biases must be provided for affine quantization.";
@@ -4046,7 +4046,11 @@ Dtype validate_mode_with_type(
           << " and biases.dtype() == " << biases->dtype() << ".";
       throw std::invalid_argument(msg.str());
     }
-    return dtype;
+    if (out_type.has_value()) {
+      return {*out_type, qmode};
+    } else {
+      return {dtype, qmode};
+    }
   }
   if (biases) {
     std::ostringstream msg;
     msg << "[" << tag << "] Biases should not be provided for quantization mode '"
         << mode << "'.";
     throw std::invalid_argument(msg.str());
   }
-  return bfloat16;
+  if (out_type.has_value()) {
+    return {*out_type, qmode};
+  } else {
+    return {bfloat16, qmode};
+  }
 }
 
 array quantized_matmul(
@@ -4071,8 +4079,8 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  auto dtype =
-      validate_mode_with_type("quantized_matmul", scales, biases, mode);
+  auto [dtype, qmode] = validate_mode_with_type(
+      "quantized_matmul", scales, biases, std::nullopt, mode);
 
   dtype = promote_types(x.dtype(), dtype);
   if (!issubdtype(dtype, floating)) {
@@ -4082,7 +4090,7 @@ array quantized_matmul(
     throw std::invalid_argument(msg.str());
   }
   std::vector<array> inputs;
-  if (mode == "affine") {
+  if (qmode == QuantizationMode::Affine) {
     inputs = {
         astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
   } else {
@@ -4099,11 +4107,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s),
-          group_size,
-          bits,
-          string_to_quantization_mode(mode),
-          transpose),
+          to_stream(s), group_size, bits, qmode, transpose),
       std::move(inputs));
 }
 
@@ -4217,53 +4221,31 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
       {w});
 }
 
-std::vector<array> quantize(
+std::vector<array> fp_quantize(
     const array& w,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
-    const std::string& mode /* = "affine" */,
-    StreamOrDevice s /* = {} */) {
-  validate_mode("quantize", mode);
-  if (!issubdtype(w.dtype(), floating)) {
+    int group_size,
+    int bits,
+    QuantizationMode mode,
+    Stream s) {
+  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
+  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
+  if (group_size != expected_gs) {
     std::ostringstream msg;
-    msg << "[quantize] Only real floating types can be quantized "
-        << "but w has type " << w.dtype() << ".";
+    msg << "[quantize] " << quantization_mode_to_string(mode)
+        << " quantization requires group size " << expected_gs << " but got "
+        << group_size << ".";
     throw std::invalid_argument(msg.str());
   }
-
-  if (w.ndim() < 2) {
+  if (bits != expected_bits) {
     std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
+    msg << "[quantize] " << quantization_mode_to_string(mode)
+        << " quantization requires bits to be " << expected_bits << " but got "
+        << bits << ".";
     throw std::invalid_argument(msg.str());
   }
-
-  if ((w.shape(-1) % group_size) != 0) {
-    std::ostringstream msg;
-    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
-        << "the quantization group size " << group_size
-        << ". However the provided "
-        << " matrix has shape " << w.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (mode == "affine") {
-    return affine_quantize(w, group_size, bits, s);
-  } else {
-    int expected_gs = (mode[0] == 'm') ? 32 : 16;
-    int expected_bits = (mode.back() == '8') ? 8 : 4;
-    if (group_size != expected_gs) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires group size "
-          << expected_gs << " but got " << group_size << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    if (bits != expected_bits) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires bits to be "
-          << expected_bits << " but got " << bits << ".";
-      throw std::invalid_argument(msg.str());
-    }
+  auto fallback = [bits = bits, group_size = group_size, s](
+                      const std::vector<array>& inputs) -> std::vector<array> {
+    auto& w = inputs[0];
     float maxval = (bits == 4) ? 6.0f : 448.0f;
     auto new_shape = w.shape();
     new_shape.back() = -1;
@@ -4314,6 +4296,57 @@ std::vector<array> quantize(
     wq = reshape(wq, new_shape, s);
     scales = reshape(scales, new_shape, s);
     return {std::move(wq), std::move(scales)};
+  };
+
+  if (s.device == Device::gpu) {
+    auto wq_shape = w.shape();
+    wq_shape.back() = w.shape(-1) * bits / 32;
+    auto sshape = w.shape();
+    sshape.back() = w.shape(-1) / group_size;
+    return array::make_arrays(
+        {std::move(wq_shape), std::move(sshape)},
+        {uint32, uint8},
+        std::make_shared<fast::Quantize>(
+            s, fallback, group_size, bits, mode, false),
+        {w});
+  }
+  return fallback({w});
+}
+
+std::vector<array> quantize(
+    const array& w,
+    int group_size /* = 64 */,
+    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
+    StreamOrDevice s /* = {} */) {
+  auto qmode = string_to_quantization_mode(mode, "quantize");
+  if (!issubdtype(w.dtype(), floating)) {
+    std::ostringstream msg;
+    msg << "[quantize] Only real floating types can be quantized "
+        << "but w has type " << w.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if ((w.shape(-1) % group_size) != 0) {
+    std::ostringstream msg;
+    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
+        << "the quantization group size " << group_size
+        << ". However the provided "
+        << " matrix has shape " << w.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
+    return affine_quantize(w, group_size, bits, s);
+  } else {
+    return fp_quantize(w, group_size, bits, qmode, to_stream(s));
   }
 }
 
@@ -4324,16 +4357,13 @@ array affine_dequantize(
     int group_size,
     int bits,
     StreamOrDevice s_) {
-  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
-    std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
   auto wshape = w.shape();
   auto sshape = scales.shape();
   auto bshape = biases.shape();
+  if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {
+    throw std::invalid_argument(
+        "[dequantize] Shape of scales and biases does not match the matrix");
+  }
   wshape.back() = -1;
   sshape.back() = -1;
   bshape.back() = -1;
@@ -4414,88 +4444,66 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
-array dequantize(
+array fp_dequantize(
     const array& w,
     const array& scales,
-    const std::optional<array>& biases /* = std::nullopt */,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
-    const std::string& mode /* = "affine" */,
-    std::optional<Dtype> dtype /* = std::nullopt */,
-    StreamOrDevice s /* = {} */) {
-  validate_mode_with_type("dequantize", scales, biases, mode);
-  if (bits <= 0) {
+    int group_size,
+    int bits,
+    Dtype out_type,
+    QuantizationMode mode,
+    Stream s) {
+  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
+  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
+  if (group_size != expected_gs) {
     std::ostringstream msg;
-    msg << "[dequantize] Invalid value for bits: " << bits;
+    msg << "[dequantize] " << quantization_mode_to_string(mode)
+        << " quantization requires group size " << expected_gs << " but got "
+        << group_size << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (group_size <= 0) {
+  if (bits != expected_bits) {
     std::ostringstream msg;
-    msg << "[dequantize] Invalid value for group_size: " << group_size;
+    msg << "[dequantize] " << quantization_mode_to_string(mode)
+        << " quantization requires bits to be " << expected_bits << " but got "
+        << bits << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (w.dtype() != uint32) {
+
+  auto wshape = w.shape();
+  auto sshape = scales.shape();
+  if (wshape.size() != sshape.size()) {
     throw std::invalid_argument(
-        "[dequantize] The matrix should be given as a uint32");
+        "[dequantize] Shape of scales does not match the matrix");
   }
-  if (mode == "affine") {
-    auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
-    if (dtype) {
-      out = astype(out, *dtype, s);
-    }
-    return out;
-  } else {
-    int expected_gs = (mode[0] == 'm') ? 32 : 16;
-    int expected_bits = (mode.back() == '8') ? 8 : 4;
-    if (group_size != expected_gs) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires group size "
-          << expected_gs << " but got " << group_size << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    if (bits != expected_bits) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires bits to be "
-          << expected_bits << " but got " << bits << ".";
-      throw std::invalid_argument(msg.str());
-    }
+  wshape.back() = -1;
+  sshape.back() = -1;
 
-    if (w.ndim() < 2 || scales.ndim() < 2) {
-      std::ostringstream msg;
-      msg << "[quantize] The matrix to be dequantized must have at least 2 dimension "
-          << "but it has only " << w.ndim() << ".";
-      throw std::invalid_argument(msg.str());
-    }
+  if (wshape != sshape) {
+    throw std::invalid_argument(
+        "[dequantize] Shape of scales does not match the matrix");
+  }
 
-    auto wshape = w.shape();
-    auto sshape = scales.shape();
-    wshape.back() = -1;
-    sshape.back() = -1;
+  // Packing into uint32
+  int out_size = w.shape(-1) * 32 / bits;
 
-    if (wshape != sshape) {
-      throw std::invalid_argument(
-          "[dequantize] Shape of scales does not match the matrix");
-    }
+  if (out_size != scales.shape(-1) * group_size) {
+    std::ostringstream msg;
+    msg << "[dequantize] Shape of scales does not match the matrix "
+        << "given the quantization parameters. Provided matrix of shape "
+        << w.shape() << " and scales of shape " << scales.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
-    if (w.dtype() != uint32) {
-      throw std::invalid_argument(
-          "[dequantize] The matrix should be given as a uint32");
-    }
-
-    // Packing into uint32
-    int out_size = w.shape(-1) * 32 / bits;
-
-    if (out_size != scales.shape(-1) * group_size) {
-      std::ostringstream msg;
-      msg << "[dequantize] Shape of scales does not match the matrix "
-          << "given the quantization parameters. Provided matrix of shape "
-          << w.shape() << " and scales of shape " << scales.shape() << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    auto out_type = dtype.has_value() ? *dtype : bfloat16;
-    auto out = w;
+  auto fallback =
+      [wshape = std::move(wshape),
+       sshape = std::move(sshape),
+       group_size,
+       bits,
+       out_type,
+       s](const std::vector<array>& inputs) mutable -> std::vector<array> {
+    auto out = inputs[0];
+    auto scales = inputs[1];
     if (bits == 4) {
       auto lut = array(
           {
@@ -4527,15 +4535,68 @@ array dequantize(
       out = from_fp8(view(out, uint8, s), out_type, s);
     }
     out = reshape(out, {-1, group_size}, s);
-    auto flat_scales = reshape(scales, {-1, 1}, s);
+    scales = reshape(scales, {-1, 1}, s);
     if (group_size == 16) {
-      flat_scales = from_fp8(flat_scales, out_type, s);
+      scales = from_fp8(scales, out_type, s);
     } else {
-      flat_scales =
-          subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
-      flat_scales = power(array(2.0f, out_type), flat_scales, s);
+      scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
+      scales = power(array(2.0f, out_type), scales, s);
     }
-    return reshape(multiply(out, flat_scales, s), wshape, s);
+    return {reshape(multiply(out, scales, s), wshape, s)};
+  };
+  if (s.device == Device::gpu) {
+    auto out_shape = w.shape();
+    out_shape.back() = out_size;
+    return array(
+        std::move(out_shape),
+        out_type,
+        std::make_shared<fast::Quantize>(
+            s, fallback, group_size, bits, mode, true),
+        {w, scales});
+  }
+  return fallback({w, scales})[0];
+}
+
+array dequantize(
+    const array& w,
+    const array& scales,
+    const std::optional<array>& biases /* = std::nullopt */,
+    int group_size /* = 64 */,
+    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
+    std::optional<Dtype> dtype /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto [out_type, qmode] =
+      validate_mode_with_type("dequantize", scales, biases, dtype, mode);
+  if (bits <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for bits: " << bits;
+    throw std::invalid_argument(msg.str());
+  }
+  if (group_size <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for group_size: " << group_size;
+    throw std::invalid_argument(msg.str());
+  }
+  if (w.dtype() != uint32) {
+    throw std::invalid_argument(
+        "[dequantize] The matrix should be given as a uint32");
+  }
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimension "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
+    return astype(
+        affine_dequantize(w, scales, *biases, group_size, bits, s),
+        out_type,
+        s);
+  } else {
+    return fp_dequantize(
+        w, scales, group_size, bits, out_type, qmode, to_stream(s));
   }
 }
 
@@ -4594,7 +4655,8 @@ array gather_qmm(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
 
-  auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode);
+  auto [out_type, qmode] =
+      validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
 
   out_type = promote_types(x.dtype(), out_type);
   if (!issubdtype(out_type, floating)) {
@@ -4634,7 +4696,7 @@ array gather_qmm(
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
   std::vector<array> inputs;
-  if (mode == "affine") {
+  if (qmode == QuantizationMode::Affine) {
     inputs = {
         astype(x, out_type, s),
         std::move(w),
@@ -4657,7 +4719,7 @@ array gather_qmm(
           to_stream(s),
           group_size,
           bits,
-          string_to_quantization_mode(mode),
+          qmode,
           transpose,
           sorted_indices && !rhs_indices_,
           sorted_indices && !lhs_indices_),
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index afd9dd5b9..976200ba3 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3328,19 +3328,37 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
 }
 
 std::string quantization_mode_to_string(QuantizationMode mode) {
-  if (mode == QuantizationMode::Affine) {
-    return "affine";
-  } else {
-    return "mxfp4";
+  switch (mode) {
+    case QuantizationMode::Affine:
+      return "affine";
+    case QuantizationMode::Mxfp4:
+      return "mxfp4";
+    case QuantizationMode::Mxfp8:
+      return "mxfp8";
+    case QuantizationMode::Nvfp4:
+    default:
+      return "nvfp4";
   }
 }
 
-QuantizationMode string_to_quantization_mode(const std::string& mode) {
+QuantizationMode string_to_quantization_mode(
+    const std::string& mode,
+    std::string_view tag /* = "" */) {
   if (mode == "affine") {
     return QuantizationMode::Affine;
-  } else {
+  } else if (mode == "mxfp4") {
     return QuantizationMode::Mxfp4;
+  } else if (mode == "mxfp8") {
+    return QuantizationMode::Mxfp8;
+  } else if (mode == "nvfp4") {
+    return QuantizationMode::Nvfp4;
   }
+  std::string msg;
+  if (!tag.empty()) {
+    msg += "[" + std::string(tag) + "]";
+  }
+  msg += " Invalid quantization mode '" + mode + "'.";
+  throw std::invalid_argument(msg);
 }
 
 std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 2a843a0e4..a1ad2425c 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive {
   UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
 };
 
-enum class QuantizationMode { Affine, Mxfp4 };
+enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
 
 std::string quantization_mode_to_string(QuantizationMode mode);
-QuantizationMode string_to_quantization_mode(const std::string& mode);
+QuantizationMode string_to_quantization_mode(
+    const std::string& mode,
+    std::string_view error_tag = "");
 
 class Abs : public UnaryPrimitive {
  public:
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 5fe867b37..d828cc69b 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -61,13 +61,18 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
 
         w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
-
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
 
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
 
+        # Invalid output type
+        with self.assertRaises(ValueError):
+            mx.dequantize(
+                w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
+            )
+
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
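
Usage sketch (illustrative, not part of the patch itself): the new mxfp4 / mxfp8 / nvfp4 modes are reached through the existing mx.quantize / mx.dequantize API that the test above exercises. Each mode fixes its group size and bit width (mxfp4: 32/4, mxfp8: 32/8, nvfp4: 16/4) and returns packed uint32 weights plus uint8 scales, with no biases. The input shape and dtype below are arbitrary example choices.

    import mlx.core as mx

    # Any float matrix whose last dimension is divisible by the group size.
    w = mx.random.normal(shape=(128, 256)).astype(mx.bfloat16)

    # mxfp4: fp4 (e2m1) elements, one power-of-two (e8m0) scale per group of 32.
    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")

    # nvfp4: fp4 elements, one fp8 (e4m3) scale per group of 16.
    w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")

    # mxfp8: fp8 (e4m3) elements, one e8m0 scale per group of 32.
    w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")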