mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-30 10:56:41 +08:00
mxfp4 quantize/dequantize + start of optional biases
This commit is contained in:
parent
8ec8d44ee6
commit
4cf90c9762
@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
|
||||
|
||||
} // namespace
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
void fast::Quantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
|
||||
nvtx3::scoped_range r("Quantize::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
@ -129,7 +129,7 @@ NO_CPU(Inverse)
|
||||
NO_CPU(View)
|
||||
|
||||
namespace fast {
|
||||
NO_CPU_MULTI(AffineQuantize)
|
||||
NO_CPU_MULTI(Quantize)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
@ -154,7 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(Quantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
|
@ -335,7 +335,7 @@ struct PrimitiveFactory {
|
||||
SERIALIZE_PRIMITIVE(Cholesky),
|
||||
SERIALIZE_PRIMITIVE(Eig),
|
||||
SERIALIZE_PRIMITIVE(Eigh),
|
||||
SERIALIZE_PRIMITIVE(AffineQuantize),
|
||||
SERIALIZE_PRIMITIVE(Quantize),
|
||||
SERIALIZE_PRIMITIVE(RMSNorm),
|
||||
SERIALIZE_PRIMITIVE(RMSNormVJP),
|
||||
SERIALIZE_PRIMITIVE(LayerNorm),
|
||||
|
213
mlx/fast.cpp
213
mlx/fast.cpp
@ -806,211 +806,14 @@ array pack_and_quantize(
|
||||
return packed_w;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array>
|
||||
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 32, 64, and 128.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (bits < 2 || bits > 8 || bits == 7) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested number of bits " << bits
|
||||
<< " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto fallback = [group_size, bits, s](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
auto& w = inputs[0];
|
||||
auto wshape = w.shape();
|
||||
wshape.back() = -1;
|
||||
|
||||
array zero(0, float32);
|
||||
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
|
||||
array eps(1e-7, float32);
|
||||
|
||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
w_max = astype(w_max, float32, s);
|
||||
w_min = astype(w_min, float32, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales =
|
||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales, s), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||
|
||||
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
||||
|
||||
scales = astype(scales, w.dtype(), s);
|
||||
biases = astype(biases, w.dtype(), s);
|
||||
return {
|
||||
reshape(packed_w, wshape, s),
|
||||
reshape(scales, wshape, s),
|
||||
reshape(biases, wshape, s),
|
||||
};
|
||||
};
|
||||
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) * bits / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
auto outputs = array::make_arrays(
|
||||
{std::move(wq_shape), sshape, sshape},
|
||||
{uint32, w.dtype(), w.dtype()},
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w});
|
||||
return {outputs[0], outputs[1], outputs[2]};
|
||||
}
|
||||
|
||||
array affine_dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int group_size,
|
||||
int bits,
|
||||
StreamOrDevice s_) {
|
||||
if (bits <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
auto bshape = biases.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
bshape.back() = -1;
|
||||
|
||||
if (wshape != sshape || wshape != bshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales and biases does not match the matrix");
|
||||
}
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales and biases does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales/biases of shape " << scales.shape()
|
||||
<< " with group_size=" << group_size << " and bits=" << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
|
||||
auto fallback =
|
||||
[wshape = std::move(wshape),
|
||||
sshape = std::move(sshape),
|
||||
group_size,
|
||||
bits,
|
||||
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||
auto w = inputs[0];
|
||||
auto& scales = inputs[1];
|
||||
auto& biases = inputs[2];
|
||||
if (is_power_of_2(bits)) {
|
||||
std::vector<array> parts;
|
||||
for (int start = 0; start < 32; start += bits) {
|
||||
int shift_left = 32 - (start + bits);
|
||||
int shift_right = shift_left + start;
|
||||
|
||||
parts.push_back(expand_dims(
|
||||
right_shift(
|
||||
left_shift(w, array(32 - (start + bits), uint32), s),
|
||||
array(32 - bits, uint32),
|
||||
s),
|
||||
-1,
|
||||
s));
|
||||
}
|
||||
w = concatenate(parts, -1, s);
|
||||
} else {
|
||||
w = expand_dims(w, /* axis= */ -1, s);
|
||||
w = bitwise_and(
|
||||
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
|
||||
auto new_shape = w.shape();
|
||||
new_shape[new_shape.size() - 2] = -1;
|
||||
new_shape.back() = bits;
|
||||
w = reshape(w, new_shape, s);
|
||||
array shifts = arange(bits, uint32, s);
|
||||
w = sum(
|
||||
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
|
||||
}
|
||||
|
||||
// Dequantize
|
||||
wshape.push_back(group_size);
|
||||
w = reshape(w, wshape, s);
|
||||
w = multiply(w, expand_dims(scales, -1, s), s);
|
||||
w = add(w, expand_dims(biases, -1, s), s);
|
||||
w = reshape(w, sshape, s);
|
||||
|
||||
return {w};
|
||||
};
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = out_size;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
scales.dtype(),
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
|
||||
{w, scales, biases});
|
||||
}
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
bool AffineQuantize::is_equivalent(const Primitive& other) const {
|
||||
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
|
||||
bool Quantize::is_equivalent(const Primitive& other) const {
|
||||
const Quantize& p_other = static_cast<const Quantize&>(other);
|
||||
return (
|
||||
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
|
||||
p_other.dequantize_ == dequantize_);
|
||||
p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_);
|
||||
}
|
||||
|
||||
std::vector<Shape> AffineQuantize::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
|
||||
auto& w = inputs[0];
|
||||
if (dequantize_) {
|
||||
auto out_size = w.shape(-1) * 32 / bits_;
|
||||
@ -1022,8 +825,12 @@ std::vector<Shape> AffineQuantize::output_shapes(
|
||||
wq_shape.back() = w.shape(-1) * bits_ / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size_;
|
||||
auto bshape = sshape;
|
||||
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||
if (inputs.size() == 2) {
|
||||
return {std::move(wq_shape), std::move(sshape)};
|
||||
} else {
|
||||
auto bshape = sshape;
|
||||
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
14
mlx/fast.h
14
mlx/fast.h
@ -52,20 +52,6 @@ array scaled_dot_product_attention(
|
||||
const std::vector<array>& mask_arrs = {},
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::tuple<array, array, array> affine_quantize(
|
||||
const array& w,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array affine_dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
using TemplateArg = std::variant<int, bool, Dtype>;
|
||||
using ScalarArg = std::variant<bool, int, float>;
|
||||
|
||||
|
@ -245,17 +245,19 @@ class ScaledDotProductAttention : public Custom {
|
||||
bool do_causal_;
|
||||
};
|
||||
|
||||
class AffineQuantize : public Custom {
|
||||
class Quantize : public Custom {
|
||||
public:
|
||||
explicit AffineQuantize(
|
||||
explicit Quantize(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& mode,
|
||||
bool dequantize)
|
||||
: Custom(stream, fallback),
|
||||
group_size_(group_size),
|
||||
bits_(bits),
|
||||
mode_(mode),
|
||||
dequantize_(dequantize) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
@ -264,17 +266,18 @@ class AffineQuantize : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_NAME(AffineQuantize);
|
||||
DEFINE_NAME(Quantize);
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
|
||||
return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
int bits_;
|
||||
std::string mode_;
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
|
484
mlx/ops.cpp
484
mlx/ops.cpp
@ -10,7 +10,7 @@
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const std::optional<array>& biases,
|
||||
bool transpose,
|
||||
int group_size,
|
||||
int bits) {
|
||||
@ -87,11 +87,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (scales.shape() != biases.shape()) {
|
||||
if (biases && scales.shape() != biases->shape()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Scales and biases should have the same shape. "
|
||||
<< "Received scales with shape " << scales.shape()
|
||||
<< " and biases with " << biases.shape();
|
||||
<< " and biases with " << biases->shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -99,9 +99,9 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag
|
||||
<< "] Weight, scales and biases should have the same batch shape. "
|
||||
<< "] Weight and scales should have the same batch shape. "
|
||||
<< "Received weight with shape " << w.shape() << ", scales with "
|
||||
<< scales.shape() << " and biases with " << biases.shape();
|
||||
<< scales.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -4021,11 +4021,50 @@ array conv_general(
|
||||
{in, wt});
|
||||
}
|
||||
|
||||
void validate_mode(std::string_view tag, const std::string& mode) {
|
||||
if (mode != "affine" && mode != "mxfp4") {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype validate_mode_with_type(
|
||||
std::string_view tag,
|
||||
const array& scales,
|
||||
const std::optional<array>& biases,
|
||||
const std::string& mode) {
|
||||
validate_mode(tag, mode);
|
||||
if (mode == "affine") {
|
||||
if (!biases) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Biases must be provided for affine quantization.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto dtype = result_type(scales, *biases);
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Only real floating types are supported but "
|
||||
<< "scales.dtype() == " << scales.dtype()
|
||||
<< " and biases.dtype() == " << biases->dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return dtype;
|
||||
}
|
||||
if (biases) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Biases must be null for quantization mode '" << mode
|
||||
<< "'.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return bfloat16;
|
||||
}
|
||||
|
||||
array quantized_matmul(
|
||||
array x,
|
||||
array w,
|
||||
array scales,
|
||||
array biases,
|
||||
std::optional<array> biases /* = std::nullopt */,
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
@ -4035,17 +4074,23 @@ array quantized_matmul(
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
auto dtype =
|
||||
validate_mode_with_type("quantized_matmul", scales, biases, mode);
|
||||
dtype = promote_types(x.dtype(), dtype);
|
||||
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||
<< "the passed types where x.dtype() == " << x.dtype()
|
||||
<< ", scales.dtype() == " << scales.dtype()
|
||||
<< " and biases.dtype() == " << biases.dtype();
|
||||
<< "x.dtype() == " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
std::vector<array> inputs = {
|
||||
astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
|
||||
std::vector<array> inputs;
|
||||
if (mode == "affine") {
|
||||
inputs = {
|
||||
astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
|
||||
} else {
|
||||
throw std::invalid_argument("ERROR!");
|
||||
}
|
||||
|
||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||
inputs = broadcast_arrays(inputs, {-2, -1}, s);
|
||||
@ -4061,31 +4106,413 @@ array quantized_matmul(
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
std::tuple<array, array, array> quantize(
|
||||
array pack_and_quantize(
|
||||
array& packed_w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int bits,
|
||||
const Stream& s) {
|
||||
int el_per_int = 32 / bits;
|
||||
array zero(0, packed_w.dtype());
|
||||
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
|
||||
packed_w = astype(
|
||||
clip(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||
zero,
|
||||
n_bins,
|
||||
s),
|
||||
uint32,
|
||||
s);
|
||||
if (is_power_of_2(bits)) {
|
||||
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
|
||||
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
} else {
|
||||
// This is slow but we have fast GPU/CPU versions of this function so we
|
||||
// shouldn't be here often.
|
||||
packed_w = expand_dims(packed_w, /* axis= */ -1, s);
|
||||
packed_w = bitwise_and(
|
||||
right_shift(packed_w, arange(bits, uint32, s), s),
|
||||
array({1}, uint32),
|
||||
s);
|
||||
auto new_shape = packed_w.shape();
|
||||
new_shape[new_shape.size() - 2] = -1;
|
||||
new_shape.back() = 32;
|
||||
packed_w = reshape(packed_w, new_shape, s);
|
||||
array shifts = arange(32, uint32, s);
|
||||
packed_w =
|
||||
sum(left_shift(packed_w, shifts, s),
|
||||
/* axis= */ -1,
|
||||
/* keepdims= */ false,
|
||||
s);
|
||||
}
|
||||
return packed_w;
|
||||
}
|
||||
|
||||
std::vector<array>
|
||||
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
auto s = to_stream(s_);
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 32, 64, and 128.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (bits < 2 || bits > 8 || bits == 7) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested number of bits " << bits
|
||||
<< " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto fallback = [group_size, bits, s](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
auto& w = inputs[0];
|
||||
auto wshape = w.shape();
|
||||
wshape.back() = -1;
|
||||
|
||||
array zero(0, float32);
|
||||
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
|
||||
array eps(1e-7, float32);
|
||||
|
||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
w_max = astype(w_max, float32, s);
|
||||
w_min = astype(w_min, float32, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales =
|
||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales, s), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||
|
||||
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
||||
|
||||
scales = astype(scales, w.dtype(), s);
|
||||
biases = astype(biases, w.dtype(), s);
|
||||
return {
|
||||
reshape(packed_w, wshape, s),
|
||||
reshape(scales, wshape, s),
|
||||
reshape(biases, wshape, s),
|
||||
};
|
||||
};
|
||||
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) * bits / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
return array::make_arrays(
|
||||
{std::move(wq_shape), sshape, sshape},
|
||||
{uint32, w.dtype(), w.dtype()},
|
||||
std::make_shared<fast::Quantize>(
|
||||
s, fallback, group_size, bits, "affine", false),
|
||||
{w});
|
||||
}
|
||||
|
||||
std::vector<array> quantize(
|
||||
const array& w,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return fast::affine_quantize(w, group_size, bits, s);
|
||||
validate_mode("quantize", mode);
|
||||
if (!issubdtype(w.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Only real floating types can be quantized "
|
||||
<< "but w has type " << w.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (mode == "affine") {
|
||||
return affine_quantize(w, group_size, bits, s);
|
||||
} else {
|
||||
if (group_size != 32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] mxfp4 quantization requires group size 32 "
|
||||
<< "but got " << group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits != 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] mxfp4 quantization requires bits to be 4 "
|
||||
<< "but got " << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto lut = array({
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f,
|
||||
});
|
||||
lut = astype(lut, w.dtype(), s);
|
||||
|
||||
auto new_shape = w.shape();
|
||||
new_shape.back() = -1;
|
||||
auto wq = reshape(w, {-1, group_size}, s);
|
||||
auto scales =
|
||||
divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s);
|
||||
scales = astype(log2(scales, s), int32, s);
|
||||
wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
|
||||
scales = astype(add(scales, array(127, int32), s), uint8, s);
|
||||
wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
|
||||
auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
|
||||
wq = reshape(wq, {-1, group_size / 8, 8}, s);
|
||||
wq = sum(multiply(wq, shifts, s), -1, false, s);
|
||||
wq = reshape(wq, new_shape, s);
|
||||
scales = reshape(scales, new_shape, s);
|
||||
return {std::move(wq), std::move(scales)};
|
||||
}
|
||||
}
|
||||
|
||||
array affine_dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int group_size,
|
||||
int bits,
|
||||
StreamOrDevice s_) {
|
||||
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
auto bshape = biases.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
bshape.back() = -1;
|
||||
|
||||
if (wshape != sshape || wshape != bshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales and biases does not match the matrix");
|
||||
}
|
||||
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales and biases does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales/biases of shape " << scales.shape()
|
||||
<< " with group_size=" << group_size << " and bits=" << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
|
||||
auto fallback =
|
||||
[wshape = std::move(wshape),
|
||||
sshape = std::move(sshape),
|
||||
group_size,
|
||||
bits,
|
||||
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||
auto w = inputs[0];
|
||||
auto& scales = inputs[1];
|
||||
auto& biases = inputs[2];
|
||||
if (is_power_of_2(bits)) {
|
||||
std::vector<array> parts;
|
||||
for (int start = 0; start < 32; start += bits) {
|
||||
int shift_left = 32 - (start + bits);
|
||||
int shift_right = shift_left + start;
|
||||
|
||||
parts.push_back(expand_dims(
|
||||
right_shift(
|
||||
left_shift(w, array(32 - (start + bits), uint32), s),
|
||||
array(32 - bits, uint32),
|
||||
s),
|
||||
-1,
|
||||
s));
|
||||
}
|
||||
w = concatenate(parts, -1, s);
|
||||
} else {
|
||||
w = expand_dims(w, /* axis= */ -1, s);
|
||||
w = bitwise_and(
|
||||
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
|
||||
auto new_shape = w.shape();
|
||||
new_shape[new_shape.size() - 2] = -1;
|
||||
new_shape.back() = bits;
|
||||
w = reshape(w, new_shape, s);
|
||||
array shifts = arange(bits, uint32, s);
|
||||
w = sum(
|
||||
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
|
||||
}
|
||||
|
||||
// Dequantize
|
||||
wshape.push_back(group_size);
|
||||
w = reshape(w, wshape, s);
|
||||
w = multiply(w, expand_dims(scales, -1, s), s);
|
||||
w = add(w, expand_dims(biases, -1, s), s);
|
||||
w = reshape(w, sshape, s);
|
||||
|
||||
return {w};
|
||||
};
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = out_size;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
scales.dtype(),
|
||||
std::make_shared<fast::Quantize>(
|
||||
s, fallback, group_size, bits, "affine", true),
|
||||
{w, scales, biases});
|
||||
}
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const std::optional<array>& biases /* = std::nullopt */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
|
||||
validate_mode_with_type("dequantize", scales, biases, mode);
|
||||
if (bits <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
if (mode == "affine") {
|
||||
return affine_dequantize(w, scales, *biases, group_size, bits, s);
|
||||
} else {
|
||||
if (group_size != 32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] mxfp4 quantization requires group size 32 "
|
||||
<< "but got " << group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits != 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] mxfp4 quantization requires bits to be 4 "
|
||||
<< "but got " << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2 || scales.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
|
||||
if (wshape != sshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales does not match the matrix");
|
||||
}
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales of shape " << scales.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto dtype = bfloat16;
|
||||
auto lut = array(
|
||||
{
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f,
|
||||
},
|
||||
dtype);
|
||||
|
||||
auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s);
|
||||
|
||||
auto idx_lo = bitwise_and(what, array(0x0F, int8), s);
|
||||
auto idx_hi = right_shift(what, array(4, int8), s);
|
||||
auto lo = gather(lut, idx_lo, 0, {1}, s);
|
||||
auto hi = gather(lut, idx_hi, 0, {1}, s);
|
||||
what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s);
|
||||
auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s);
|
||||
exponent = reshape(exponent, {-1, 1}, s);
|
||||
return reshape(
|
||||
multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s);
|
||||
}
|
||||
}
|
||||
|
||||
array gather_qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const std::optional<array>& biases /* = std::nullopt */,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||
bool transpose /* = true */,
|
||||
@ -4102,6 +4529,16 @@ array gather_qmm(
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode);
|
||||
out_type = promote_types(x.dtype(), out_type);
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[gather_qmm] Only real floating types are supported but "
|
||||
<< "x.dtype() == " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Extract indices and broadcast them
|
||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||
array rhs_indices = indices_or_default(rhs_indices_, w, s);
|
||||
@ -4117,6 +4554,12 @@ array gather_qmm(
|
||||
throw std::invalid_argument(
|
||||
"[gather_qmm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
if (x.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[gather_qmm] Non-quantized input must have at least two"
|
||||
<< " dimensions but got input with shape " << x.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
rhs_indices = astype(rhs_indices, uint32, s);
|
||||
@ -4126,9 +4569,6 @@ array gather_qmm(
|
||||
out_shape.push_back(x.shape(-2));
|
||||
out_shape.push_back(w_outer_dims);
|
||||
|
||||
// and output type
|
||||
auto out_type = result_type(x, scales, biases);
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
@ -4143,7 +4583,7 @@ array gather_qmm(
|
||||
{astype(x, out_type, s),
|
||||
std::move(w),
|
||||
astype(scales, out_type, s),
|
||||
astype(biases, out_type, s),
|
||||
astype(*biases, out_type, s),
|
||||
std::move(lhs_indices),
|
||||
std::move(rhs_indices)});
|
||||
}
|
||||
|
@ -1322,7 +1322,7 @@ array quantized_matmul(
|
||||
array x,
|
||||
array w,
|
||||
array scales,
|
||||
array biases,
|
||||
std::optional<array> biases = std::nullopt,
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
@ -1330,7 +1330,7 @@ array quantized_matmul(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantize a matrix along its last axis */
|
||||
std::tuple<array, array, array> quantize(
|
||||
std::vector<array> quantize(
|
||||
const array& w,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
@ -1341,7 +1341,7 @@ std::tuple<array, array, array> quantize(
|
||||
array dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const std::optional<array>& biases = std::nullopt,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
const std::string& mode = "affine",
|
||||
@ -1352,7 +1352,7 @@ array gather_qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const std::optional<array>& biases = std::nullopt,
|
||||
std::optional<array> lhs_indices = std::nullopt,
|
||||
std::optional<array> rhs_indices = std::nullopt,
|
||||
bool transpose = true,
|
||||
|
@ -98,9 +98,11 @@ class QuantizedEmbedding(Module):
|
||||
# Initialize the quantized weight
|
||||
scale = math.sqrt(1 / dims)
|
||||
weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
weight, group_size, bits, mode=mode
|
||||
)
|
||||
self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
|
||||
if mode == "affine":
|
||||
self.scales, self.biases = scales_biases
|
||||
else:
|
||||
self.scales = scales_biases
|
||||
self.num_embeddings = num_embeddings
|
||||
self.dims = dims
|
||||
|
||||
@ -108,10 +110,11 @@ class QuantizedEmbedding(Module):
|
||||
self.freeze()
|
||||
|
||||
def __call__(self, x):
|
||||
biases = self.get("biases")
|
||||
return mx.dequantize(
|
||||
self["weight"][x],
|
||||
scales=self["scales"][x],
|
||||
biases=self["biases"][x],
|
||||
biases=biases[x] if biases is not None else None,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
mode=self.mode,
|
||||
@ -128,7 +131,7 @@ class QuantizedEmbedding(Module):
|
||||
x,
|
||||
self["weight"],
|
||||
scales=self["scales"],
|
||||
biases=self["biases"],
|
||||
biases=self.get("biases"),
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
@ -207,9 +210,11 @@ class QuantizedLinear(Module):
|
||||
high=scale,
|
||||
shape=(output_dims, input_dims),
|
||||
)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
weight, group_size, bits, mode=mode
|
||||
)
|
||||
self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
|
||||
if mode == "affine":
|
||||
self.scales, self.biases = scales_biases
|
||||
else:
|
||||
self.scales = scales_biases
|
||||
|
||||
# And bias if needed
|
||||
if bias:
|
||||
@ -231,7 +236,7 @@ class QuantizedLinear(Module):
|
||||
x,
|
||||
self["weight"],
|
||||
scales=self["scales"],
|
||||
biases=self["biases"],
|
||||
biases=self.get("biases"),
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
@ -252,12 +257,17 @@ class QuantizedLinear(Module):
|
||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||
output_dims, input_dims = linear_layer.weight.shape
|
||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
ql.weight, scales_biases = mx.quantize(
|
||||
linear_layer.weight,
|
||||
group_size,
|
||||
bits,
|
||||
mode=mode,
|
||||
)
|
||||
if mode == "affine":
|
||||
ql.scales, ql.biases = scales_biases
|
||||
else:
|
||||
ql.scales = scales_biases
|
||||
|
||||
if "bias" in linear_layer:
|
||||
ql.bias = linear_layer.bias
|
||||
|
||||
|
@ -4149,7 +4149,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"biases"_a = nb::none(),
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
@ -4157,7 +4157,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||
quantization uses one floating point scale and bias per ``group_size`` of
|
||||
@ -4168,7 +4168,8 @@ void init_ops(nb::module_& m) {
|
||||
x (array): Input array
|
||||
w (array): Quantized matrix packed in unsigned integers
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
biases (array, optional): The biases to use per ``group_size``
|
||||
elements of ``w``. Default: ``None``.
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
@ -4216,11 +4217,11 @@ void init_ops(nb::module_& m) {
|
||||
mode (str, optional): The quantization mode. Default: ``"affine"``.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
tuple: A tuple with either two or three elements containing:
|
||||
|
||||
* w_q (array): The quantized version of ``w``
|
||||
* scales (array): The scale to multiply each element with, namely :math:`s`
|
||||
* biases (array): The biases to add to each element, namely :math:`\beta`
|
||||
* scales (array): The quantization scales
|
||||
* biases (array): The quantization biases (returned for `mode=="affine"`).
|
||||
|
||||
Notes:
|
||||
The currently supported quantization mode is `"affine"`.
|
||||
@ -4252,14 +4253,14 @@ void init_ops(nb::module_& m) {
|
||||
&mx::dequantize,
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"biases"_a = nb::none(),
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
@ -4268,7 +4269,8 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
biases (array, optional): The biases to use per ``group_size``
|
||||
elements of ``w``. Default: ``None``.
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
@ -4294,7 +4296,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"biases"_a = nb::none(),
|
||||
"lhs_indices"_a = nb::none(),
|
||||
"rhs_indices"_a = nb::none(),
|
||||
"transpose"_a = true,
|
||||
@ -4305,7 +4307,7 @@ void init_ops(nb::module_& m) {
|
||||
"sorted_indices"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform quantized matrix multiplication with matrix-level gather.
|
||||
|
||||
@ -4321,7 +4323,8 @@ void init_ops(nb::module_& m) {
|
||||
x (array): Input array
|
||||
w (array): Quantized matrix packed in unsigned integers
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
biases (array, optional): The biases to use per ``group_size``
|
||||
elements of ``w``. Default: ``None``.
|
||||
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
||||
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
|
@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||
self.assertTrue(mx.all(a_hat == 0))
|
||||
|
||||
def test_mxfp4_quantize_dequantize(self):
|
||||
lut = mx.array(
|
||||
[
|
||||
+0.0,
|
||||
+0.5,
|
||||
+1.0,
|
||||
+1.5,
|
||||
+2.0,
|
||||
+3.0,
|
||||
+4.0,
|
||||
+6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
]
|
||||
)
|
||||
w = lut[mx.random.randint(0, 16, shape=(128, 512))]
|
||||
w = w.reshape(-1, 32)
|
||||
w[:, 0] = 6
|
||||
w = (w + 3e-6).astype(mx.bfloat16)
|
||||
|
||||
# Invalid bits / group size
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
|
||||
|
||||
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
|
||||
|
||||
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
|
||||
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
|
||||
|
||||
# test quantize/dequantize 0s
|
||||
a = mx.zeros((256, 512))
|
||||
w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
|
||||
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
|
||||
self.assertTrue(mx.all(w_hat == 0))
|
||||
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
@ -233,6 +283,71 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
def test_mode_error_cases(self):
|
||||
w = mx.random.normal(shape=(256, 256))
|
||||
x = mx.random.normal(shape=(1, 256))
|
||||
|
||||
# Invalid mode
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantize(w, mode="xyz")
|
||||
|
||||
wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(
|
||||
x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
|
||||
)
|
||||
|
||||
rhs_indices = mx.array(0)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.gather_qmm(
|
||||
x,
|
||||
wq,
|
||||
scales,
|
||||
biases,
|
||||
rhs_indices=rhs_indices,
|
||||
bits=4,
|
||||
group_size=32,
|
||||
mode="xyz",
|
||||
)
|
||||
|
||||
# Only quantize floating point types
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantize(mx.zeros((128, 128), mx.int32))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")
|
||||
|
||||
# Must have bias for affine
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(wq, scales, None, bits=4, group_size=32)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.gather_qmm(
|
||||
x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
|
||||
)
|
||||
|
||||
# Must be floating point
|
||||
x = mx.zeros(shape=(256,), dtype=mx.int32)
|
||||
scales = mx.zeros(scales.shape, dtype=mx.int32)
|
||||
biases = mx.zeros(scales.shape, dtype=mx.int32)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(wq, scales, biases, bits=4, group_size=32)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.gather_qmm(
|
||||
x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
|
||||
)
|
||||
|
||||
def test_throw(self):
|
||||
x = mx.random.normal(shape=(10, 512))
|
||||
w = mx.random.normal(shape=(32, 512))
|
||||
|
Loading…
Reference in New Issue
Block a user