fast cuda kernel for mx/nv quantization

This commit is contained in:
Awni Hannun
2025-10-21 11:49:58 -07:00
parent c00ccf7404
commit c961a3a557
9 changed files with 492 additions and 161 deletions

View File

@@ -51,6 +51,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

View File

@@ -306,7 +306,7 @@ void affine_dequantize(
enc.set_input_array(scales); enc.set_input_array(scales);
enc.set_input_array(biases); enc.set_input_array(biases);
enc.set_output_array(w); enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) { dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) { dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) { dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;

View File

@@ -0,0 +1,218 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
if constexpr (bits == 8) {
return float(*(__nv_fp8_e4m3*)(&x));
} else {
return float(*(__nv_fp4_e2m1*)(&x));
}
}
};
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t out_index = tidx + grid_dim_x * size_t(tidy);
size_t in_index = out_index;
if (in_index >= size) {
return;
}
float w_thread = w[in_index];
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.__x;
scale = float(s);
// Write out the scales
size_t gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = 0;
uint8_t val = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
output = val;
if (bits == 4) {
uint8_t sval = warp.shfl_down(val, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (out_index % pack_factor == 0) {
out[out_index / pack_factor] = output;
}
}
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto scale = float(((ScaleType*)(scales))[gindex]);
out += oindex;
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
} // namespace cu
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not quantize input with type float64.");
}
});
}
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
constexpr int uint8_per_uint32 = 4;
int packs_per_int = 8 / bits;
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_dequantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_dequantize<T, 16, 4, false>;
}
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
}
});
}
} // namespace mlx::core

View File

@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
if (dequantize_) { if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s); auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0]; auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes())); w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
}
} else { } else {
auto w = ensure_row_contiguous(inputs[0], enc, s); auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0]; auto& wq = outputs[0];
auto& scales = outputs[1]; auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes())); wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes())); scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes())); if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
}
} }
} }

View File

@@ -24,4 +24,22 @@ void affine_dequantize(
cu::CommandEncoder& enc, cu::CommandEncoder& enc,
const Stream& s); const Stream& s);
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -4017,22 +4017,22 @@ array conv_general(
{in, wt}); {in, wt});
} }
void validate_mode(std::string_view tag, const std::string& mode) { std::pair<Dtype, QuantizationMode> validate_mode_with_type(
if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
mode != "nvfp4") {
std::ostringstream msg;
msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
throw std::invalid_argument(msg.str());
}
}
Dtype validate_mode_with_type(
std::string_view tag, std::string_view tag,
const array& scales, const array& scales,
const std::optional<array>& biases, const std::optional<array>& biases,
const std::optional<Dtype> out_type,
const std::string& mode) { const std::string& mode) {
validate_mode(tag, mode); auto qmode = string_to_quantization_mode(mode, tag);
if (mode == "affine") { // TODO add tests for out_type
if (out_type.has_value() && !issubdtype(*out_type, floating)) {
std::ostringstream msg;
msg << "[" << tag << "] Only real floating types are supported but "
<< "output dtype == " << *out_type << ".";
throw std::invalid_argument(msg.str());
}
if (qmode == QuantizationMode::Affine) {
if (!biases) { if (!biases) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << tag << "] Biases must be provided for affine quantization."; msg << "[" << tag << "] Biases must be provided for affine quantization.";
@@ -4046,7 +4046,11 @@ Dtype validate_mode_with_type(
<< " and biases.dtype() == " << biases->dtype() << "."; << " and biases.dtype() == " << biases->dtype() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
return dtype; if (out_type.has_value()) {
return {*out_type, qmode};
} else {
return {dtype, qmode};
}
} }
if (biases) { if (biases) {
std::ostringstream msg; std::ostringstream msg;
@@ -4054,7 +4058,11 @@ Dtype validate_mode_with_type(
<< "'."; << "'.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
return bfloat16; if (out_type.has_value()) {
return {*out_type, qmode};
} else {
return {bfloat16, qmode};
}
} }
array quantized_matmul( array quantized_matmul(
@@ -4071,8 +4079,8 @@ array quantized_matmul(
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits); "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
auto dtype = auto [dtype, qmode] = validate_mode_with_type(
validate_mode_with_type("quantized_matmul", scales, biases, mode); "quantized_matmul", scales, biases, std::nullopt, mode);
dtype = promote_types(x.dtype(), dtype); dtype = promote_types(x.dtype(), dtype);
if (!issubdtype(dtype, floating)) { if (!issubdtype(dtype, floating)) {
@@ -4082,7 +4090,7 @@ array quantized_matmul(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
std::vector<array> inputs; std::vector<array> inputs;
if (mode == "affine") { if (qmode == QuantizationMode::Affine) {
inputs = { inputs = {
astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
} else { } else {
@@ -4099,11 +4107,7 @@ array quantized_matmul(
std::move(out_shape), std::move(out_shape),
dtype, dtype,
std::make_shared<QuantizedMatmul>( std::make_shared<QuantizedMatmul>(
to_stream(s), to_stream(s), group_size, bits, qmode, transpose),
group_size,
bits,
string_to_quantization_mode(mode),
transpose),
std::move(inputs)); std::move(inputs));
} }
@@ -4217,53 +4221,31 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
{w}); {w});
} }
std::vector<array> quantize( std::vector<array> fp_quantize(
const array& w, const array& w,
int group_size /* = 64 */, int group_size,
int bits /* = 4 */, int bits,
const std::string& mode /* = "affine" */, QuantizationMode mode,
StreamOrDevice s /* = {} */) { Stream s) {
validate_mode("quantize", mode); int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
if (!issubdtype(w.dtype(), floating)) { int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
if (group_size != expected_gs) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] Only real floating types can be quantized " msg << "[quantize] " << quantization_mode_to_string(mode)
<< "but w has type " << w.dtype() << "."; << " quantization requires group size " << expected_gs << " but got "
<< group_size << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (bits != expected_bits) {
if (w.ndim() < 2) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension " msg << "[quantize] " << quantization_mode_to_string(mode)
<< "but it has only " << w.ndim() << "."; << " quantization requires bits to be " << expected_bits << " but got "
<< bits << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
auto fallback = [bits = bits, group_size = group_size, s](
if ((w.shape(-1) % group_size) != 0) { const std::vector<array>& inputs) -> std::vector<array> {
std::ostringstream msg; auto& w = inputs[0];
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided "
<< " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
if (mode == "affine") {
return affine_quantize(w, group_size, bits, s);
} else {
int expected_gs = (mode[0] == 'm') ? 32 : 16;
int expected_bits = (mode.back() == '8') ? 8 : 4;
if (group_size != expected_gs) {
std::ostringstream msg;
msg << "[quantize] " << mode << " quantization requires group size "
<< expected_gs << " but got " << group_size << ".";
throw std::invalid_argument(msg.str());
}
if (bits != expected_bits) {
std::ostringstream msg;
msg << "[quantize] " << mode << " quantization requires bits to be "
<< expected_bits << " but got " << bits << ".";
throw std::invalid_argument(msg.str());
}
float maxval = (bits == 4) ? 6.0f : 448.0f; float maxval = (bits == 4) ? 6.0f : 448.0f;
auto new_shape = w.shape(); auto new_shape = w.shape();
new_shape.back() = -1; new_shape.back() = -1;
@@ -4314,6 +4296,57 @@ std::vector<array> quantize(
wq = reshape(wq, new_shape, s); wq = reshape(wq, new_shape, s);
scales = reshape(scales, new_shape, s); scales = reshape(scales, new_shape, s);
return {std::move(wq), std::move(scales)}; return {std::move(wq), std::move(scales)};
};
if (s.device == Device::gpu) {
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) * bits / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
return array::make_arrays(
{std::move(wq_shape), std::move(sshape)},
{uint32, uint8},
std::make_shared<fast::Quantize>(
s, fallback, group_size, bits, mode, false),
{w});
}
return fallback({w});
}
std::vector<array> quantize(
const array& w,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) {
auto qmode = string_to_quantization_mode(mode, "quantize");
if (!issubdtype(w.dtype(), floating)) {
std::ostringstream msg;
msg << "[quantize] Only real floating types can be quantized "
<< "but w has type " << w.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided "
<< " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
if (qmode == QuantizationMode::Affine) {
return affine_quantize(w, group_size, bits, s);
} else {
return fp_quantize(w, group_size, bits, qmode, to_stream(s));
} }
} }
@@ -4324,16 +4357,13 @@ array affine_dequantize(
int group_size, int group_size,
int bits, int bits,
StreamOrDevice s_) { StreamOrDevice s_) {
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape(); auto wshape = w.shape();
auto sshape = scales.shape(); auto sshape = scales.shape();
auto bshape = biases.shape(); auto bshape = biases.shape();
if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
wshape.back() = -1; wshape.back() = -1;
sshape.back() = -1; sshape.back() = -1;
bshape.back() = -1; bshape.back() = -1;
@@ -4414,88 +4444,66 @@ array affine_dequantize(
return fallback({w, scales, biases})[0]; return fallback({w, scales, biases})[0];
} }
array dequantize( array fp_dequantize(
const array& w, const array& w,
const array& scales, const array& scales,
const std::optional<array>& biases /* = std::nullopt */, int group_size,
int group_size /* = 64 */, int bits,
int bits /* = 4 */, Dtype out_type,
const std::string& mode /* = "affine" */, QuantizationMode mode,
std::optional<Dtype> dtype /* = std::nullopt */, Stream s) {
StreamOrDevice s /* = {} */) { int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
validate_mode_with_type("dequantize", scales, biases, mode); int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
if (bits <= 0) { if (group_size != expected_gs) {
std::ostringstream msg; std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits; msg << "[dequantize] " << quantization_mode_to_string(mode)
<< " quantization requires group size " << expected_gs << " but got "
<< group_size << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (group_size <= 0) { if (bits != expected_bits) {
std::ostringstream msg; std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size; msg << "[dequantize] " << quantization_mode_to_string(mode)
<< " quantization requires bits to be " << expected_bits << " but got "
<< bits << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (w.dtype() != uint32) {
auto wshape = w.shape();
auto sshape = scales.shape();
if (wshape.size() != sshape.size()) {
throw std::invalid_argument( throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32"); "[dequantize] Shape of scales does not match the matrix");
} }
if (mode == "affine") { wshape.back() = -1;
auto out = affine_dequantize(w, scales, *biases, group_size, bits, s); sshape.back() = -1;
if (dtype) {
out = astype(out, *dtype, s);
}
return out;
} else {
int expected_gs = (mode[0] == 'm') ? 32 : 16;
int expected_bits = (mode.back() == '8') ? 8 : 4;
if (group_size != expected_gs) {
std::ostringstream msg;
msg << "[quantize] " << mode << " quantization requires group size "
<< expected_gs << " but got " << group_size << ".";
throw std::invalid_argument(msg.str());
}
if (bits != expected_bits) {
std::ostringstream msg;
msg << "[quantize] " << mode << " quantization requires bits to be "
<< expected_bits << " but got " << bits << ".";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2) { if (wshape != sshape) {
std::ostringstream msg; throw std::invalid_argument(
msg << "[quantize] The matrix to be dequantized must have at least 2 dimension " "[dequantize] Shape of scales does not match the matrix");
<< "but it has only " << w.ndim() << "."; }
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape(); // Packing into uint32
auto sshape = scales.shape(); int out_size = w.shape(-1) * 32 / bits;
wshape.back() = -1;
sshape.back() = -1;
if (wshape != sshape) { if (out_size != scales.shape(-1) * group_size) {
throw std::invalid_argument( std::ostringstream msg;
"[dequantize] Shape of scales does not match the matrix"); msg << "[dequantize] Shape of scales does not match the matrix "
} << "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales of shape " << scales.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (w.dtype() != uint32) { auto fallback =
throw std::invalid_argument( [wshape = std::move(wshape),
"[dequantize] The matrix should be given as a uint32"); sshape = std::move(sshape),
} group_size,
bits,
// Packing into uint32 out_type,
int out_size = w.shape(-1) * 32 / bits; s](const std::vector<array>& inputs) mutable -> std::vector<array> {
auto out = inputs[0];
if (out_size != scales.shape(-1) * group_size) { auto scales = inputs[1];
std::ostringstream msg;
msg << "[dequantize] Shape of scales does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales of shape " << scales.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto out_type = dtype.has_value() ? *dtype : bfloat16;
auto out = w;
if (bits == 4) { if (bits == 4) {
auto lut = array( auto lut = array(
{ {
@@ -4527,15 +4535,68 @@ array dequantize(
out = from_fp8(view(out, uint8, s), out_type, s); out = from_fp8(view(out, uint8, s), out_type, s);
} }
out = reshape(out, {-1, group_size}, s); out = reshape(out, {-1, group_size}, s);
auto flat_scales = reshape(scales, {-1, 1}, s); scales = reshape(scales, {-1, 1}, s);
if (group_size == 16) { if (group_size == 16) {
flat_scales = from_fp8(flat_scales, out_type, s); scales = from_fp8(scales, out_type, s);
} else { } else {
flat_scales = scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
subtract(astype(flat_scales, out_type, s), array(127, out_type), s); scales = power(array(2.0f, out_type), scales, s);
flat_scales = power(array(2.0f, out_type), flat_scales, s);
} }
return reshape(multiply(out, flat_scales, s), wshape, s); return {reshape(multiply(out, scales, s), wshape, s)};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = out_size;
return array(
std::move(out_shape),
out_type,
std::make_shared<fast::Quantize>(
s, fallback, group_size, bits, mode, true),
{w, scales});
}
return fallback({w, scales})[0];
}
array dequantize(
const array& w,
const array& scales,
const std::optional<array>& biases /* = std::nullopt */,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
std::optional<Dtype> dtype /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto [out_type, qmode] =
validate_mode_with_type("dequantize", scales, biases, dtype, mode);
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[dequantize] The matrix to be dequantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if (qmode == QuantizationMode::Affine) {
return astype(
affine_dequantize(w, scales, *biases, group_size, bits, s),
out_type,
s);
} else {
return fp_dequantize(
w, scales, group_size, bits, out_type, qmode, to_stream(s));
} }
} }
@@ -4594,7 +4655,8 @@ array gather_qmm(
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits); "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode); auto [out_type, qmode] =
validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
out_type = promote_types(x.dtype(), out_type); out_type = promote_types(x.dtype(), out_type);
if (!issubdtype(out_type, floating)) { if (!issubdtype(out_type, floating)) {
@@ -4634,7 +4696,7 @@ array gather_qmm(
out_shape.push_back(x.shape(-2)); out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer_dims); out_shape.push_back(w_outer_dims);
std::vector<array> inputs; std::vector<array> inputs;
if (mode == "affine") { if (qmode == QuantizationMode::Affine) {
inputs = { inputs = {
astype(x, out_type, s), astype(x, out_type, s),
std::move(w), std::move(w),
@@ -4657,7 +4719,7 @@ array gather_qmm(
to_stream(s), to_stream(s),
group_size, group_size,
bits, bits,
string_to_quantization_mode(mode), qmode,
transpose, transpose,
sorted_indices && !rhs_indices_, sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_), sorted_indices && !lhs_indices_),

View File

@@ -3328,19 +3328,37 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
} }
std::string quantization_mode_to_string(QuantizationMode mode) { std::string quantization_mode_to_string(QuantizationMode mode) {
if (mode == QuantizationMode::Affine) { switch (mode) {
return "affine"; case QuantizationMode::Affine:
} else { return "affine";
return "mxfp4"; case QuantizationMode::Mxfp4:
return "mxfp4";
case QuantizationMode::Mxfp8:
return "mxfp8";
case QuantizationMode::Nvfp4:
default:
return "nvfp4";
} }
} }
QuantizationMode string_to_quantization_mode(const std::string& mode) { QuantizationMode string_to_quantization_mode(
const std::string& mode,
std::string_view tag /* = "" */) {
if (mode == "affine") { if (mode == "affine") {
return QuantizationMode::Affine; return QuantizationMode::Affine;
} else { } else if (mode == "mxfp4") {
return QuantizationMode::Mxfp4; return QuantizationMode::Mxfp4;
} else if (mode == "mxfp8") {
return QuantizationMode::Mxfp8;
} else if (mode == "nvfp4") {
return QuantizationMode::Nvfp4;
} }
std::string msg;
if (!tag.empty()) {
msg += "[" + std::string(tag) + "]";
}
msg += " Invalid quantization mode '" + mode + "'.";
throw std::invalid_argument(msg);
} }
std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap( std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(

View File

@@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive {
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
}; };
enum class QuantizationMode { Affine, Mxfp4 }; enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
std::string quantization_mode_to_string(QuantizationMode mode); std::string quantization_mode_to_string(QuantizationMode mode);
QuantizationMode string_to_quantization_mode(const std::string& mode); QuantizationMode string_to_quantization_mode(
const std::string& mode,
std::string_view error_tag = "");
class Abs : public UnaryPrimitive { class Abs : public UnaryPrimitive {
public: public:

View File

@@ -61,13 +61,18 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.quantize(w, group_size=64, bits=4, mode="mxfp4") mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4") w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4") mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4") mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
# Invalid output type
with self.assertRaises(ValueError):
mx.dequantize(
w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
)
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))