diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 008001c50..71c687d85 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix( } // namespace -void fast::AffineQuantize::eval_gpu( +void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { - nvtx3::scoped_range r("AffineQuantize::eval_gpu"); + nvtx3::scoped_range r("Quantize::eval_gpu"); auto& s = stream(); auto& d = cu::device(s.device); auto& enc = d.get_command_encoder(s); diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 09e6c4ef3..dba82c6dc 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -129,7 +129,7 @@ NO_CPU(Inverse) NO_CPU(View) namespace fast { -NO_CPU_MULTI(AffineQuantize) +NO_CPU_MULTI(Quantize) } // namespace fast namespace distributed { diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index dfe5b57f1..22a0c8acc 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -154,7 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) -NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(Quantize) NO_GPU_MULTI(CustomKernel) } // namespace fast diff --git a/mlx/export.cpp b/mlx/export.cpp index 7099f4864..19944dfc4 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -335,7 +335,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Cholesky), SERIALIZE_PRIMITIVE(Eig), SERIALIZE_PRIMITIVE(Eigh), - SERIALIZE_PRIMITIVE(AffineQuantize), + SERIALIZE_PRIMITIVE(Quantize), SERIALIZE_PRIMITIVE(RMSNorm), SERIALIZE_PRIMITIVE(RMSNormVJP), SERIALIZE_PRIMITIVE(LayerNorm), diff --git a/mlx/fast.cpp b/mlx/fast.cpp index b8d622253..2917b1584 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -806,211 +806,14 @@ array pack_and_quantize( return packed_w; } -std::tuple -affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { - auto s = to_stream(s_); - - if (group_size != 32 && group_size != 64 && group_size != 128) { - std::ostringstream msg; - msg << "[quantize] The requested group size " << group_size - << " is not supported. The supported group sizes are 32, 64, and 128."; - throw std::invalid_argument(msg.str()); - } - - if (bits < 2 || bits > 8 || bits == 7) { - std::ostringstream msg; - msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; - throw std::invalid_argument(msg.str()); - } - - if (w.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - if ((w.shape(-1) % group_size) != 0) { - std::ostringstream msg; - msg << "[quantize] The last dimension of the matrix needs to be divisible by " - << "the quantization group size " << group_size - << ". However the provided " << " matrix has shape " << w.shape(); - throw std::invalid_argument(msg.str()); - } - - auto fallback = [group_size, bits, s]( - const std::vector& inputs) -> std::vector { - auto& w = inputs[0]; - auto wshape = w.shape(); - wshape.back() = -1; - - array zero(0, float32); - array n_bins((1 << bits) - 1, float32); // 2**bits - 1 - array eps(1e-7, float32); - - array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); - - array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - w_max = astype(w_max, float32, s); - w_min = astype(w_min, float32, s); - - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = - maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales, s), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge, s); - - packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); - - scales = astype(scales, w.dtype(), s); - biases = astype(biases, w.dtype(), s); - return { - reshape(packed_w, wshape, s), - reshape(scales, wshape, s), - reshape(biases, wshape, s), - }; - }; - - auto wq_shape = w.shape(); - wq_shape.back() = w.shape(-1) * bits / 32; - auto sshape = w.shape(); - sshape.back() = w.shape(-1) / group_size; - auto outputs = array::make_arrays( - {std::move(wq_shape), sshape, sshape}, - {uint32, w.dtype(), w.dtype()}, - std::make_shared(s, fallback, group_size, bits, false), - {w}); - return {outputs[0], outputs[1], outputs[2]}; -} - -array affine_dequantize( - const array& w, - const array& scales, - const array& biases, - int group_size, - int bits, - StreamOrDevice s_) { - if (bits <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for bits: " << bits; - throw std::invalid_argument(msg.str()); - } - if (group_size <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for group_size: " << group_size; - throw std::invalid_argument(msg.str()); - } - if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - auto wshape = w.shape(); - auto sshape = scales.shape(); - auto bshape = biases.shape(); - wshape.back() = -1; - sshape.back() = -1; - bshape.back() = -1; - - if (wshape != sshape || wshape != bshape) { - throw std::invalid_argument( - "[dequantize] Shape of scales and biases does not match the matrix"); - } - - if (w.dtype() != uint32) { - throw std::invalid_argument( - "[dequantize] The matrix should be given as a uint32"); - } - - // Packing into uint32 - int out_size = w.shape(-1) * 32 / bits; - - if (out_size != scales.shape(-1) * group_size) { - std::ostringstream msg; - msg << "[dequantize] Shape of scales and biases does not match the matrix " - << "given the quantization parameters. Provided matrix of shape " - << w.shape() << " and scales/biases of shape " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits << "."; - throw std::invalid_argument(msg.str()); - } - - auto s = to_stream(s_); - - auto fallback = - [wshape = std::move(wshape), - sshape = std::move(sshape), - group_size, - bits, - s](const std::vector& inputs) mutable -> std::vector { - auto w = inputs[0]; - auto& scales = inputs[1]; - auto& biases = inputs[2]; - if (is_power_of_2(bits)) { - std::vector parts; - for (int start = 0; start < 32; start += bits) { - int shift_left = 32 - (start + bits); - int shift_right = shift_left + start; - - parts.push_back(expand_dims( - right_shift( - left_shift(w, array(32 - (start + bits), uint32), s), - array(32 - bits, uint32), - s), - -1, - s)); - } - w = concatenate(parts, -1, s); - } else { - w = expand_dims(w, /* axis= */ -1, s); - w = bitwise_and( - right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s); - auto new_shape = w.shape(); - new_shape[new_shape.size() - 2] = -1; - new_shape.back() = bits; - w = reshape(w, new_shape, s); - array shifts = arange(bits, uint32, s); - w = sum( - left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s); - } - - // Dequantize - wshape.push_back(group_size); - w = reshape(w, wshape, s); - w = multiply(w, expand_dims(scales, -1, s), s); - w = add(w, expand_dims(biases, -1, s), s); - w = reshape(w, sshape, s); - - return {w}; - }; - - if (s.device == Device::gpu) { - auto out_shape = w.shape(); - out_shape.back() = out_size; - return array( - std::move(out_shape), - scales.dtype(), - std::make_shared(s, fallback, group_size, bits, true), - {w, scales, biases}); - } - return fallback({w, scales, biases})[0]; -} - -bool AffineQuantize::is_equivalent(const Primitive& other) const { - const AffineQuantize& p_other = static_cast(other); +bool Quantize::is_equivalent(const Primitive& other) const { + const Quantize& p_other = static_cast(other); return ( p_other.group_size_ == group_size_ && p_other.bits_ == bits_ && - p_other.dequantize_ == dequantize_); + p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_); } -std::vector AffineQuantize::output_shapes( - const std::vector& inputs) { +std::vector Quantize::output_shapes(const std::vector& inputs) { auto& w = inputs[0]; if (dequantize_) { auto out_size = w.shape(-1) * 32 / bits_; @@ -1022,8 +825,12 @@ std::vector AffineQuantize::output_shapes( wq_shape.back() = w.shape(-1) * bits_ / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size_; - auto bshape = sshape; - return {std::move(wq_shape), std::move(sshape), std::move(bshape)}; + if (inputs.size() == 2) { + return {std::move(wq_shape), std::move(sshape)}; + } else { + auto bshape = sshape; + return {std::move(wq_shape), std::move(sshape), std::move(bshape)}; + } } } diff --git a/mlx/fast.h b/mlx/fast.h index d154e4753..10f9ced96 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -52,20 +52,6 @@ array scaled_dot_product_attention( const std::vector& mask_arrs = {}, StreamOrDevice s = {}); -std::tuple affine_quantize( - const array& w, - int group_size = 64, - int bits = 4, - StreamOrDevice s = {}); - -array affine_dequantize( - const array& w, - const array& scales, - const array& biases, - int group_size = 64, - int bits = 4, - StreamOrDevice s = {}); - using TemplateArg = std::variant; using ScalarArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index e0e83f726..d6ab26018 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -245,17 +245,19 @@ class ScaledDotProductAttention : public Custom { bool do_causal_; }; -class AffineQuantize : public Custom { +class Quantize : public Custom { public: - explicit AffineQuantize( + explicit Quantize( Stream stream, std::function(std::vector)> fallback, int group_size, int bits, + const std::string& mode, bool dequantize) : Custom(stream, fallback), group_size_(group_size), bits_(bits), + mode_(mode), dequantize_(dequantize) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) @@ -264,17 +266,18 @@ class AffineQuantize : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_NAME(AffineQuantize); + DEFINE_NAME(Quantize); bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(nullptr, group_size_, bits_, dequantize_); + return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_); } private: int group_size_; int bits_; + std::string mode_; bool dequantize_; }; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4b1df9908..bde34d54b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -10,7 +10,7 @@ #include #include -#include "mlx/fast.h" +#include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" @@ -76,7 +76,7 @@ std::pair extract_quantized_matmul_dims( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, bool transpose, int group_size, int bits) { @@ -87,11 +87,11 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } - if (scales.shape() != biases.shape()) { + if (biases && scales.shape() != biases->shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " << "Received scales with shape " << scales.shape() - << " and biases with " << biases.shape(); + << " and biases with " << biases->shape(); throw std::invalid_argument(msg.str()); } @@ -99,9 +99,9 @@ std::pair extract_quantized_matmul_dims( w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { std::ostringstream msg; msg << "[" << tag - << "] Weight, scales and biases should have the same batch shape. " + << "] Weight and scales should have the same batch shape. " << "Received weight with shape " << w.shape() << ", scales with " - << scales.shape() << " and biases with " << biases.shape(); + << scales.shape() << "."; throw std::invalid_argument(msg.str()); } @@ -4021,11 +4021,50 @@ array conv_general( {in, wt}); } +void validate_mode(std::string_view tag, const std::string& mode) { + if (mode != "affine" && mode != "mxfp4") { + std::ostringstream msg; + msg << "[" << tag << "] Invalid quantization mode '" << mode << "'."; + throw std::invalid_argument(msg.str()); + } +} + +Dtype validate_mode_with_type( + std::string_view tag, + const array& scales, + const std::optional& biases, + const std::string& mode) { + validate_mode(tag, mode); + if (mode == "affine") { + if (!biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be provided for affine quantization."; + throw std::invalid_argument(msg.str()); + } + auto dtype = result_type(scales, *biases); + if (!issubdtype(dtype, floating)) { + std::ostringstream msg; + msg << "[" << tag << "] Only real floating types are supported but " + << "scales.dtype() == " << scales.dtype() + << " and biases.dtype() == " << biases->dtype() << "."; + throw std::invalid_argument(msg.str()); + } + return dtype; + } + if (biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be null for quantization mode '" << mode + << "'."; + throw std::invalid_argument(msg.str()); + } + return bfloat16; +} + array quantized_matmul( array x, array w, array scales, - array biases, + std::optional biases /* = std::nullopt */, bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, @@ -4035,17 +4074,23 @@ array quantized_matmul( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); - auto dtype = result_type(x, scales, biases); + auto dtype = + validate_mode_with_type("quantized_matmul", scales, biases, mode); + dtype = promote_types(x.dtype(), dtype); + if (!issubdtype(dtype, floating)) { std::ostringstream msg; msg << "[quantized_matmul] Only real floating types are supported but " - << "the passed types where x.dtype() == " << x.dtype() - << ", scales.dtype() == " << scales.dtype() - << " and biases.dtype() == " << biases.dtype(); + << "x.dtype() == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } - std::vector inputs = { - astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)}; + std::vector inputs; + if (mode == "affine") { + inputs = { + astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; + } else { + throw std::invalid_argument("ERROR!"); + } if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); @@ -4061,31 +4106,413 @@ array quantized_matmul( std::move(inputs)); } -std::tuple quantize( +array pack_and_quantize( + array& packed_w, + const array& scales, + const array& biases, + int bits, + const Stream& s) { + int el_per_int = 32 / bits; + array zero(0, packed_w.dtype()); + array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 + packed_w = astype( + clip( + round(divide(subtract(packed_w, biases, s), scales, s), s), + zero, + n_bins, + s), + uint32, + s); + if (is_power_of_2(bits)) { + array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); + packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); + packed_w = sum( + multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + } else { + // This is slow but we have fast GPU/CPU versions of this function so we + // shouldn't be here often. + packed_w = expand_dims(packed_w, /* axis= */ -1, s); + packed_w = bitwise_and( + right_shift(packed_w, arange(bits, uint32, s), s), + array({1}, uint32), + s); + auto new_shape = packed_w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = 32; + packed_w = reshape(packed_w, new_shape, s); + array shifts = arange(32, uint32, s); + packed_w = + sum(left_shift(packed_w, shifts, s), + /* axis= */ -1, + /* keepdims= */ false, + s); + } + return packed_w; +} + +std::vector +affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { + auto s = to_stream(s_); + if (group_size != 32 && group_size != 64 && group_size != 128) { + std::ostringstream msg; + msg << "[quantize] The requested group size " << group_size + << " is not supported. The supported group sizes are 32, 64, and 128."; + throw std::invalid_argument(msg.str()); + } + + if (bits < 2 || bits > 8 || bits == 7) { + std::ostringstream msg; + msg << "[quantize] The requested number of bits " << bits + << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; + throw std::invalid_argument(msg.str()); + } + + auto fallback = [group_size, bits, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + auto wshape = w.shape(); + wshape.back() = -1; + + array zero(0, float32); + array n_bins((1 << bits) - 1, float32); // 2**bits - 1 + array eps(1e-7, float32); + + array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); + + array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + w_max = astype(w_max, float32, s); + w_min = astype(w_min, float32, s); + + array mask = greater(abs(w_min, s), abs(w_max, s), s); + array scales = + maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales, s), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + array biases = where(equal(q0, zero, s), zero, edge, s); + + packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); + + scales = astype(scales, w.dtype(), s); + biases = astype(biases, w.dtype(), s); + return { + reshape(packed_w, wshape, s), + reshape(scales, wshape, s), + reshape(biases, wshape, s), + }; + }; + + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) * bits / 32; + auto sshape = w.shape(); + sshape.back() = w.shape(-1) / group_size; + return array::make_arrays( + {std::move(wq_shape), sshape, sshape}, + {uint32, w.dtype(), w.dtype()}, + std::make_shared( + s, fallback, group_size, bits, "affine", false), + {w}); +} + +std::vector quantize( const array& w, int group_size /* = 64 */, int bits /* = 4 */, const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { - return fast::affine_quantize(w, group_size, bits, s); + validate_mode("quantize", mode); + if (!issubdtype(w.dtype(), floating)) { + std::ostringstream msg; + msg << "[quantize] Only real floating types can be quantized " + << "but w has type " << w.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (w.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + if ((w.shape(-1) % group_size) != 0) { + std::ostringstream msg; + msg << "[quantize] The last dimension of the matrix needs to be divisible by " + << "the quantization group size " << group_size + << ". However the provided " << " matrix has shape " << w.shape(); + throw std::invalid_argument(msg.str()); + } + + if (mode == "affine") { + return affine_quantize(w, group_size, bits, s); + } else { + if (group_size != 32) { + std::ostringstream msg; + msg << "[quantize] mxfp4 quantization requires group size 32 " + << "but got " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != 4) { + std::ostringstream msg; + msg << "[quantize] mxfp4 quantization requires bits to be 4 " + << "but got " << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto lut = array({ + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }); + lut = astype(lut, w.dtype(), s); + + auto new_shape = w.shape(); + new_shape.back() = -1; + auto wq = reshape(w, {-1, group_size}, s); + auto scales = + divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s); + scales = astype(log2(scales, s), int32, s); + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); + auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); + wq = reshape(wq, {-1, group_size / 8, 8}, s); + wq = sum(multiply(wq, shifts, s), -1, false, s); + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales)}; + } +} + +array affine_dequantize( + const array& w, + const array& scales, + const array& biases, + int group_size, + int bits, + StreamOrDevice s_) { + if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + auto bshape = biases.shape(); + wshape.back() = -1; + sshape.back() = -1; + bshape.back() = -1; + + if (wshape != sshape || wshape != bshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales and biases does not match the matrix"); + } + + // Packing into uint32 + int out_size = w.shape(-1) * 32 / bits; + + if (out_size != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales and biases does not match the matrix " + << "given the quantization parameters. Provided matrix of shape " + << w.shape() << " and scales/biases of shape " << scales.shape() + << " with group_size=" << group_size << " and bits=" << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + + auto fallback = + [wshape = std::move(wshape), + sshape = std::move(sshape), + group_size, + bits, + s](const std::vector& inputs) mutable -> std::vector { + auto w = inputs[0]; + auto& scales = inputs[1]; + auto& biases = inputs[2]; + if (is_power_of_2(bits)) { + std::vector parts; + for (int start = 0; start < 32; start += bits) { + int shift_left = 32 - (start + bits); + int shift_right = shift_left + start; + + parts.push_back(expand_dims( + right_shift( + left_shift(w, array(32 - (start + bits), uint32), s), + array(32 - bits, uint32), + s), + -1, + s)); + } + w = concatenate(parts, -1, s); + } else { + w = expand_dims(w, /* axis= */ -1, s); + w = bitwise_and( + right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s); + auto new_shape = w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = bits; + w = reshape(w, new_shape, s); + array shifts = arange(bits, uint32, s); + w = sum( + left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s); + } + + // Dequantize + wshape.push_back(group_size); + w = reshape(w, wshape, s); + w = multiply(w, expand_dims(scales, -1, s), s); + w = add(w, expand_dims(biases, -1, s), s); + w = reshape(w, sshape, s); + + return {w}; + }; + + if (s.device == Device::gpu) { + auto out_shape = w.shape(); + out_shape.back() = out_size; + return array( + std::move(out_shape), + scales.dtype(), + std::make_shared( + s, fallback, group_size, bits, "affine", true), + {w, scales, biases}); + } + return fallback({w, scales, biases})[0]; } array dequantize( const array& w, const array& scales, - const array& biases, + const std::optional& biases /* = std::nullopt */, int group_size /* = 64 */, int bits /* = 4 */, const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { - return fast::affine_dequantize(w, scales, biases, group_size, bits, s); + validate_mode_with_type("dequantize", scales, biases, mode); + if (bits <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for bits: " << bits; + throw std::invalid_argument(msg.str()); + } + if (group_size <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for group_size: " << group_size; + throw std::invalid_argument(msg.str()); + } + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + if (mode == "affine") { + return affine_dequantize(w, scales, *biases, group_size, bits, s); + } else { + if (group_size != 32) { + std::ostringstream msg; + msg << "[dequantize] mxfp4 quantization requires group size 32 " + << "but got " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != 4) { + std::ostringstream msg; + msg << "[dequantize] mxfp4 quantization requires bits to be 4 " + << "but got " << bits << "."; + throw std::invalid_argument(msg.str()); + } + + if (w.ndim() < 2 || scales.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + wshape.back() = -1; + sshape.back() = -1; + + if (wshape != sshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales does not match the matrix"); + } + + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + // Packing into uint32 + int out_size = w.shape(-1) * 32 / bits; + + if (out_size != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales does not match the matrix " + << "given the quantization parameters. Provided matrix of shape " + << w.shape() << " and scales of shape " << scales.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto dtype = bfloat16; + auto lut = array( + { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }, + dtype); + + auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s); + + auto idx_lo = bitwise_and(what, array(0x0F, int8), s); + auto idx_hi = right_shift(what, array(4, int8), s); + auto lo = gather(lut, idx_lo, 0, {1}, s); + auto hi = gather(lut, idx_hi, 0, {1}, s); + what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s); + auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s); + exponent = reshape(exponent, {-1, 1}, s); + return reshape( + multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s); + } } array gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases /* = std::nullopt */, std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, bool transpose /* = true */, @@ -4102,6 +4529,16 @@ array gather_qmm( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "gather_qmm", x, w, scales, biases, transpose, group_size, bits); + auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode); + out_type = promote_types(x.dtype(), out_type); + + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[gather_qmm] Only real floating types are supported but " + << "x.dtype() == " << x.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + // Extract indices and broadcast them array lhs_indices = indices_or_default(lhs_indices_, x, s); array rhs_indices = indices_or_default(rhs_indices_, w, s); @@ -4117,6 +4554,12 @@ array gather_qmm( throw std::invalid_argument( "[gather_qmm] Got rhs_indices with invalid dtype. Indices must be integral."); } + if (x.ndim() < 2) { + std::ostringstream msg; + msg << "[gather_qmm] Non-quantized input must have at least two" + << " dimensions but got input with shape " << x.shape() << "."; + throw std::invalid_argument(msg.str()); + } lhs_indices = astype(lhs_indices, uint32, s); rhs_indices = astype(rhs_indices, uint32, s); @@ -4126,9 +4569,6 @@ array gather_qmm( out_shape.push_back(x.shape(-2)); out_shape.push_back(w_outer_dims); - // and output type - auto out_type = result_type(x, scales, biases); - return array( std::move(out_shape), out_type, @@ -4143,7 +4583,7 @@ array gather_qmm( {astype(x, out_type, s), std::move(w), astype(scales, out_type, s), - astype(biases, out_type, s), + astype(*biases, out_type, s), std::move(lhs_indices), std::move(rhs_indices)}); } diff --git a/mlx/ops.h b/mlx/ops.h index 37ab9800c..826f6d47b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1322,7 +1322,7 @@ array quantized_matmul( array x, array w, array scales, - array biases, + std::optional biases = std::nullopt, bool transpose = true, int group_size = 64, int bits = 4, @@ -1330,7 +1330,7 @@ array quantized_matmul( StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ -std::tuple quantize( +std::vector quantize( const array& w, int group_size = 64, int bits = 4, @@ -1341,7 +1341,7 @@ std::tuple quantize( array dequantize( const array& w, const array& scales, - const array& biases, + const std::optional& biases = std::nullopt, int group_size = 64, int bits = 4, const std::string& mode = "affine", @@ -1352,7 +1352,7 @@ array gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases = std::nullopt, std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, bool transpose = true, diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 5894d7c15..c85b55e90 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -98,9 +98,11 @@ class QuantizedEmbedding(Module): # Initialize the quantized weight scale = math.sqrt(1 / dims) weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) - self.weight, self.scales, self.biases = mx.quantize( - weight, group_size, bits, mode=mode - ) + self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + if mode == "affine": + self.scales, self.biases = scales_biases + else: + self.scales = scales_biases self.num_embeddings = num_embeddings self.dims = dims @@ -108,10 +110,11 @@ class QuantizedEmbedding(Module): self.freeze() def __call__(self, x): + biases = self.get("biases") return mx.dequantize( self["weight"][x], scales=self["scales"][x], - biases=self["biases"][x], + biases=biases[x] if biases is not None else None, group_size=self.group_size, bits=self.bits, mode=self.mode, @@ -128,7 +131,7 @@ class QuantizedEmbedding(Module): x, self["weight"], scales=self["scales"], - biases=self["biases"], + biases=self.get("biases"), transpose=True, group_size=self.group_size, bits=self.bits, @@ -207,9 +210,11 @@ class QuantizedLinear(Module): high=scale, shape=(output_dims, input_dims), ) - self.weight, self.scales, self.biases = mx.quantize( - weight, group_size, bits, mode=mode - ) + self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + if mode == "affine": + self.scales, self.biases = scales_biases + else: + self.scales = scales_biases # And bias if needed if bias: @@ -231,7 +236,7 @@ class QuantizedLinear(Module): x, self["weight"], scales=self["scales"], - biases=self["biases"], + biases=self.get("biases"), transpose=True, group_size=self.group_size, bits=self.bits, @@ -252,12 +257,17 @@ class QuantizedLinear(Module): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits) - ql.weight, ql.scales, ql.biases = mx.quantize( + ql.weight, scales_biases = mx.quantize( linear_layer.weight, group_size, bits, mode=mode, ) + if mode == "affine": + ql.scales, ql.biases = scales_biases + else: + ql.scales = scales_biases + if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 090ec842f..cb0add614 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4149,7 +4149,7 @@ void init_ops(nb::module_& m) { nb::arg(), nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, @@ -4157,7 +4157,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4168,7 +4168,8 @@ void init_ops(nb::module_& m) { x (array): Input array w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing ``x @ w.T`` or ``x @ w``. Default: ``True``. @@ -4216,11 +4217,11 @@ void init_ops(nb::module_& m) { mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: - tuple: A tuple containing + tuple: A tuple with either two or three elements containing: * w_q (array): The quantized version of ``w`` - * scales (array): The scale to multiply each element with, namely :math:`s` - * biases (array): The biases to add to each element, namely :math:`\beta` + * scales (array): The quantization scales + * biases (array): The quantization biases (returned for `mode=="affine"`). Notes: The currently supported quantization mode is `"affine"`. @@ -4252,14 +4253,14 @@ void init_ops(nb::module_& m) { &mx::dequantize, nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "group_size"_a = 64, "bits"_a = 4, "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dequantize the matrix ``w`` using quantization parameters. @@ -4268,7 +4269,8 @@ void init_ops(nb::module_& m) { Args: w (array): Matrix to be quantized scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. group_size (int, optional): The size of the group in ``w`` that shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in @@ -4294,7 +4296,7 @@ void init_ops(nb::module_& m) { nb::arg(), nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), "transpose"_a = true, @@ -4305,7 +4307,7 @@ void init_ops(nb::module_& m) { "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4321,7 +4323,8 @@ void init_ops(nb::module_& m) { x (array): Input array w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. transpose (bool, optional): Defines whether to multiply with the diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 90a57221f..eb2826dc8 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase): a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) + def test_mxfp4_quantize_dequantize(self): + lut = mx.array( + [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + ) + w = lut[mx.random.randint(0, 16, shape=(128, 512))] + w = w.reshape(-1, 32) + w[:, 0] = 6 + w = (w + 3e-6).astype(mx.bfloat16) + + # Invalid bits / group size + with self.assertRaises(ValueError): + mx.quantize(w, bits=3, group_size=32, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.quantize(w, group_size=64, bits=4, mode="mxfp4") + + w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4") + + w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + + # test quantize/dequantize 0s + a = mx.zeros((256, 512)) + w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + self.assertTrue(mx.all(w_hat == 0)) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -233,6 +283,71 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 2e-3) + def test_mode_error_cases(self): + w = mx.random.normal(shape=(256, 256)) + x = mx.random.normal(shape=(1, 256)) + + # Invalid mode + with self.assertRaises(ValueError): + mx.quantize(w, mode="xyz") + + wq, scales, biases = mx.quantize(w, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz") + + with self.assertRaises(ValueError): + mx.quantized_matmul( + x, wq, scales, biases, bits=4, group_size=32, mode="xyz" + ) + + rhs_indices = mx.array(0) + with self.assertRaises(ValueError): + mx.gather_qmm( + x, + wq, + scales, + biases, + rhs_indices=rhs_indices, + bits=4, + group_size=32, + mode="xyz", + ) + + # Only quantize floating point types + with self.assertRaises(ValueError): + mx.quantize(mx.zeros((128, 128), mx.int32)) + + with self.assertRaises(ValueError): + mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4") + + # Must have bias for affine + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, None, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.gather_qmm( + x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32 + ) + + # Must be floating point + x = mx.zeros(shape=(256,), dtype=mx.int32) + scales = mx.zeros(scales.shape, dtype=mx.int32) + biases = mx.zeros(scales.shape, dtype=mx.int32) + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, biases, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.gather_qmm( + x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32 + ) + def test_throw(self): x = mx.random.normal(shape=(10, 512)) w = mx.random.normal(shape=(32, 512))