mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fast cuda kernel for mx/nv quantization
This commit is contained in:
@@ -51,6 +51,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
@@ -306,7 +306,7 @@ void affine_dequantize(
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
218
mlx/backend/cuda/quantized/fp_quantize.cu
Normal file
218
mlx/backend/cuda/quantized/fp_quantize.cu
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/quantized/quantized.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cuda_fp4.h>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
template <int bits>
|
||||
struct Quantize {
|
||||
__device__ uint8_t operator()(float x) {
|
||||
if constexpr (bits == 8) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
} else {
|
||||
return __nv_fp4_e2m1(x).__x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int bits>
|
||||
struct Dequantize {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
if constexpr (bits == 8) {
|
||||
return float(*(__nv_fp8_e4m3*)(&x));
|
||||
} else {
|
||||
return float(*(__nv_fp4_e2m1*)(&x));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
__global__ void
|
||||
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
||||
auto block_size = cg::this_thread_block().dim_threads();
|
||||
auto block_idx = cg::this_thread_block().group_index();
|
||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||
|
||||
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
|
||||
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
|
||||
|
||||
auto grid_dim_x =
|
||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||
size_t out_index = tidx + grid_dim_x * size_t(tidy);
|
||||
size_t in_index = out_index;
|
||||
if (in_index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float w_thread = w[in_index];
|
||||
|
||||
cg::greater<float> max_op;
|
||||
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
|
||||
|
||||
float scale = cg::reduce(warp, abs(w_thread), max_op);
|
||||
scale /= bits == 4 ? 6.0f : 448.0f;
|
||||
// Convert to mx scale or nv scale
|
||||
using ScaleType =
|
||||
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
||||
auto s = ScaleType(scale);
|
||||
uint8_t q_scale = s.__x;
|
||||
scale = float(s);
|
||||
|
||||
// Write out the scales
|
||||
size_t gindex = in_index / group_size;
|
||||
if (in_index % group_size == 0) {
|
||||
scales[gindex] = q_scale;
|
||||
}
|
||||
|
||||
uint8_t output = 0;
|
||||
uint8_t val = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||
output = val;
|
||||
if (bits == 4) {
|
||||
uint8_t sval = warp.shfl_down(val, 1);
|
||||
output |= sval << bits;
|
||||
}
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
if (out_index % pack_factor == 0) {
|
||||
out[out_index / pack_factor] = output;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
__global__ void
|
||||
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
|
||||
auto block_size = cg::this_thread_block().dim_threads();
|
||||
auto block_idx = cg::this_thread_block().group_index();
|
||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||
|
||||
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
|
||||
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
|
||||
|
||||
auto grid_dim_x =
|
||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
size_t offset = tidx + grid_dim_x * size_t(tidy);
|
||||
size_t oindex = offset * pack_factor;
|
||||
|
||||
if (oindex >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t gindex = oindex / group_size;
|
||||
using ScaleType =
|
||||
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
||||
auto scale = float(((ScaleType*)(scales))[gindex]);
|
||||
|
||||
out += oindex;
|
||||
|
||||
uint val = w[offset];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < pack_factor; i++) {
|
||||
uint8_t d;
|
||||
if (bits == 4) {
|
||||
d = (val >> (bits * i)) & 0x0f;
|
||||
} else if (bits == 8) {
|
||||
d = val;
|
||||
}
|
||||
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void fp_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
enc.set_input_array(w);
|
||||
enc.set_output_array(wq);
|
||||
enc.set_output_array(scales);
|
||||
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (!std::is_same_v<T, double>) {
|
||||
auto kernel = cu::fp_quantize<T, 32, 4, true>;
|
||||
if (bits == 8) {
|
||||
kernel = cu::fp_quantize<T, 32, 8, true>;
|
||||
} else if (group_size == 16) {
|
||||
kernel = cu::fp_quantize<T, 16, 4, false>;
|
||||
}
|
||||
bool large = w.size() > UINT_MAX;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(w.size(), w.shape(), w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<uint8_t>(),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Quantize::eval_gpu] Can not quantize input with type float64.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void fp_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
int packs_per_int = 8 / bits;
|
||||
|
||||
size_t size = w.size() / packs_per_int;
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
|
||||
enc.set_input_array(wq);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (!std::is_same_v<T, double>) {
|
||||
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
|
||||
if (bits == 8) {
|
||||
kernel = cu::fp_dequantize<T, 32, 8, true>;
|
||||
} else if (group_size == 16) {
|
||||
kernel = cu::fp_dequantize<T, 16, 4, false>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
w.data<T>(),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
|
||||
if (dequantize_) {
|
||||
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
auto& w = outputs[0];
|
||||
|
||||
w.set_data(allocator::malloc(w.nbytes()));
|
||||
|
||||
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
|
||||
}
|
||||
} else {
|
||||
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto& wq = outputs[0];
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
|
||||
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto& biases = outputs[2];
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,4 +24,22 @@ void affine_dequantize(
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void fp_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void fp_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
352
mlx/ops.cpp
352
mlx/ops.cpp
@@ -3938,22 +3938,22 @@ array conv_general(
|
||||
{in, wt});
|
||||
}
|
||||
|
||||
void validate_mode(std::string_view tag, const std::string& mode) {
|
||||
if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
|
||||
mode != "nvfp4") {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype validate_mode_with_type(
|
||||
std::pair<Dtype, QuantizationMode> validate_mode_with_type(
|
||||
std::string_view tag,
|
||||
const array& scales,
|
||||
const std::optional<array>& biases,
|
||||
const std::optional<Dtype> out_type,
|
||||
const std::string& mode) {
|
||||
validate_mode(tag, mode);
|
||||
if (mode == "affine") {
|
||||
auto qmode = string_to_quantization_mode(mode, tag);
|
||||
// TODO add tests for out_type
|
||||
if (out_type.has_value() && !issubdtype(*out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Only real floating types are supported but "
|
||||
<< "output dtype == " << *out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (qmode == QuantizationMode::Affine) {
|
||||
if (!biases) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Biases must be provided for affine quantization.";
|
||||
@@ -3967,7 +3967,11 @@ Dtype validate_mode_with_type(
|
||||
<< " and biases.dtype() == " << biases->dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return dtype;
|
||||
if (out_type.has_value()) {
|
||||
return {*out_type, qmode};
|
||||
} else {
|
||||
return {dtype, qmode};
|
||||
}
|
||||
}
|
||||
if (biases) {
|
||||
std::ostringstream msg;
|
||||
@@ -3975,7 +3979,11 @@ Dtype validate_mode_with_type(
|
||||
<< "'.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return bfloat16;
|
||||
if (out_type.has_value()) {
|
||||
return {*out_type, qmode};
|
||||
} else {
|
||||
return {bfloat16, qmode};
|
||||
}
|
||||
}
|
||||
|
||||
array quantized_matmul(
|
||||
@@ -3992,8 +4000,8 @@ array quantized_matmul(
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
auto dtype =
|
||||
validate_mode_with_type("quantized_matmul", scales, biases, mode);
|
||||
auto [dtype, qmode] = validate_mode_with_type(
|
||||
"quantized_matmul", scales, biases, std::nullopt, mode);
|
||||
dtype = promote_types(x.dtype(), dtype);
|
||||
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
@@ -4003,7 +4011,7 @@ array quantized_matmul(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
std::vector<array> inputs;
|
||||
if (mode == "affine") {
|
||||
if (qmode == QuantizationMode::Affine) {
|
||||
inputs = {
|
||||
astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
|
||||
} else {
|
||||
@@ -4020,11 +4028,7 @@ array quantized_matmul(
|
||||
std::move(out_shape),
|
||||
dtype,
|
||||
std::make_shared<QuantizedMatmul>(
|
||||
to_stream(s),
|
||||
group_size,
|
||||
bits,
|
||||
string_to_quantization_mode(mode),
|
||||
transpose),
|
||||
to_stream(s), group_size, bits, qmode, transpose),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
@@ -4138,53 +4142,31 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
{w});
|
||||
}
|
||||
|
||||
std::vector<array> quantize(
|
||||
std::vector<array> fp_quantize(
|
||||
const array& w,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_mode("quantize", mode);
|
||||
if (!issubdtype(w.dtype(), floating)) {
|
||||
int group_size,
|
||||
int bits,
|
||||
QuantizationMode mode,
|
||||
Stream s) {
|
||||
int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
|
||||
int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
|
||||
if (group_size != expected_gs) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Only real floating types can be quantized "
|
||||
<< "but w has type " << w.dtype() << ".";
|
||||
msg << "[quantize] " << quantization_mode_to_string(mode)
|
||||
<< " quantization requires group size " << expected_gs << " but got "
|
||||
<< group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
if (bits != expected_bits) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
msg << "[quantize] " << quantization_mode_to_string(mode)
|
||||
<< " quantization requires bits to be " << expected_bits << " but got "
|
||||
<< bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided "
|
||||
<< " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (mode == "affine") {
|
||||
return affine_quantize(w, group_size, bits, s);
|
||||
} else {
|
||||
int expected_gs = (mode[0] == 'm') ? 32 : 16;
|
||||
int expected_bits = (mode.back() == '8') ? 8 : 4;
|
||||
if (group_size != expected_gs) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] " << mode << " quantization requires group size "
|
||||
<< expected_gs << " but got " << group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits != expected_bits) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] " << mode << " quantization requires bits to be "
|
||||
<< expected_bits << " but got " << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto fallback = [bits = bits, group_size = group_size, s](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
auto& w = inputs[0];
|
||||
float maxval = (bits == 4) ? 6.0f : 448.0f;
|
||||
auto new_shape = w.shape();
|
||||
new_shape.back() = -1;
|
||||
@@ -4235,6 +4217,57 @@ std::vector<array> quantize(
|
||||
wq = reshape(wq, new_shape, s);
|
||||
scales = reshape(scales, new_shape, s);
|
||||
return {std::move(wq), std::move(scales)};
|
||||
};
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) * bits / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
return array::make_arrays(
|
||||
{std::move(wq_shape), std::move(sshape)},
|
||||
{uint32, uint8},
|
||||
std::make_shared<fast::Quantize>(
|
||||
s, fallback, group_size, bits, mode, false),
|
||||
{w});
|
||||
}
|
||||
return fallback({w});
|
||||
}
|
||||
|
||||
std::vector<array> quantize(
|
||||
const array& w,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto qmode = string_to_quantization_mode(mode, "quantize");
|
||||
if (!issubdtype(w.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Only real floating types can be quantized "
|
||||
<< "but w has type " << w.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided "
|
||||
<< " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (qmode == QuantizationMode::Affine) {
|
||||
return affine_quantize(w, group_size, bits, s);
|
||||
} else {
|
||||
return fp_quantize(w, group_size, bits, qmode, to_stream(s));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4245,16 +4278,13 @@ array affine_dequantize(
|
||||
int group_size,
|
||||
int bits,
|
||||
StreamOrDevice s_) {
|
||||
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
auto bshape = biases.shape();
|
||||
if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales and biases does not match the matrix");
|
||||
}
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
bshape.back() = -1;
|
||||
@@ -4338,88 +4368,66 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
array fp_dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const std::optional<array>& biases /* = std::nullopt */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
std::optional<Dtype> dtype /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_mode_with_type("dequantize", scales, biases, mode);
|
||||
if (bits <= 0) {
|
||||
int group_size,
|
||||
int bits,
|
||||
Dtype out_type,
|
||||
QuantizationMode mode,
|
||||
Stream s) {
|
||||
int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
|
||||
int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
|
||||
if (group_size != expected_gs) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
msg << "[dequantize] " << quantization_mode_to_string(mode)
|
||||
<< " quantization requires group size " << expected_gs << " but got "
|
||||
<< group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
if (bits != expected_bits) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
msg << "[dequantize] " << quantization_mode_to_string(mode)
|
||||
<< " quantization requires bits to be " << expected_bits << " but got "
|
||||
<< bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.dtype() != uint32) {
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
if (wshape.size() != sshape.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
"[dequantize] Shape of scales does not match the matrix");
|
||||
}
|
||||
|
||||
if (mode == "affine") {
|
||||
auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
|
||||
if (dtype) {
|
||||
out = astype(out, *dtype, s);
|
||||
}
|
||||
return out;
|
||||
} else {
|
||||
int expected_gs = (mode[0] == 'm') ? 32 : 16;
|
||||
int expected_bits = (mode.back() == '8') ? 8 : 4;
|
||||
if (group_size != expected_gs) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] " << mode << " quantization requires group size "
|
||||
<< expected_gs << " but got " << group_size << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits != expected_bits) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] " << mode << " quantization requires bits to be "
|
||||
<< expected_bits << " but got " << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
|
||||
if (w.ndim() < 2 || scales.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be dequantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (wshape != sshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales does not match the matrix");
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (wshape != sshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales does not match the matrix");
|
||||
}
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales of shape " << scales.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales of shape " << scales.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = dtype.has_value() ? *dtype : bfloat16;
|
||||
auto out = w;
|
||||
auto fallback =
|
||||
[wshape = std::move(wshape),
|
||||
sshape = std::move(sshape),
|
||||
group_size,
|
||||
bits,
|
||||
out_type,
|
||||
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||
auto out = inputs[0];
|
||||
auto scales = inputs[1];
|
||||
if (bits == 4) {
|
||||
auto lut = array(
|
||||
{
|
||||
@@ -4451,15 +4459,68 @@ array dequantize(
|
||||
out = from_fp8(view(out, uint8, s), out_type, s);
|
||||
}
|
||||
out = reshape(out, {-1, group_size}, s);
|
||||
auto flat_scales = reshape(scales, {-1, 1}, s);
|
||||
scales = reshape(scales, {-1, 1}, s);
|
||||
if (group_size == 16) {
|
||||
flat_scales = from_fp8(flat_scales, out_type, s);
|
||||
scales = from_fp8(scales, out_type, s);
|
||||
} else {
|
||||
flat_scales =
|
||||
subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
|
||||
flat_scales = power(array(2.0f, out_type), flat_scales, s);
|
||||
scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
|
||||
scales = power(array(2.0f, out_type), scales, s);
|
||||
}
|
||||
return reshape(multiply(out, flat_scales, s), wshape, s);
|
||||
return {reshape(multiply(out, scales, s), wshape, s)};
|
||||
};
|
||||
if (s.device == Device::gpu) {
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = out_size;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<fast::Quantize>(
|
||||
s, fallback, group_size, bits, mode, true),
|
||||
{w, scales});
|
||||
}
|
||||
return fallback({w, scales})[0];
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const std::optional<array>& biases /* = std::nullopt */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
std::optional<Dtype> dtype /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto [out_type, qmode] =
|
||||
validate_mode_with_type("dequantize", scales, biases, dtype, mode);
|
||||
if (bits <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] The matrix to be dequantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (qmode == QuantizationMode::Affine) {
|
||||
return astype(
|
||||
affine_dequantize(w, scales, *biases, group_size, bits, s),
|
||||
out_type,
|
||||
s);
|
||||
} else {
|
||||
return fp_dequantize(
|
||||
w, scales, group_size, bits, out_type, qmode, to_stream(s));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4518,7 +4579,8 @@ array gather_qmm(
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode);
|
||||
auto [out_type, qmode] =
|
||||
validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
|
||||
out_type = promote_types(x.dtype(), out_type);
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
@@ -4558,7 +4620,7 @@ array gather_qmm(
|
||||
out_shape.push_back(x.shape(-2));
|
||||
out_shape.push_back(w_outer_dims);
|
||||
std::vector<array> inputs;
|
||||
if (mode == "affine") {
|
||||
if (qmode == QuantizationMode::Affine) {
|
||||
inputs = {
|
||||
astype(x, out_type, s),
|
||||
std::move(w),
|
||||
@@ -4581,7 +4643,7 @@ array gather_qmm(
|
||||
to_stream(s),
|
||||
group_size,
|
||||
bits,
|
||||
string_to_quantization_mode(mode),
|
||||
qmode,
|
||||
transpose,
|
||||
sorted_indices && !rhs_indices_,
|
||||
sorted_indices && !lhs_indices_),
|
||||
|
||||
@@ -3329,19 +3329,37 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
|
||||
}
|
||||
|
||||
std::string quantization_mode_to_string(QuantizationMode mode) {
|
||||
if (mode == QuantizationMode::Affine) {
|
||||
return "affine";
|
||||
} else {
|
||||
return "mxfp4";
|
||||
switch (mode) {
|
||||
case QuantizationMode::Affine:
|
||||
return "affine";
|
||||
case QuantizationMode::Mxfp4:
|
||||
return "mxfp4";
|
||||
case QuantizationMode::Mxfp8:
|
||||
return "mxfp8";
|
||||
case QuantizationMode::Nvfp4:
|
||||
default:
|
||||
return "nvfp4";
|
||||
}
|
||||
}
|
||||
|
||||
QuantizationMode string_to_quantization_mode(const std::string& mode) {
|
||||
QuantizationMode string_to_quantization_mode(
|
||||
const std::string& mode,
|
||||
std::string_view tag /* = "" */) {
|
||||
if (mode == "affine") {
|
||||
return QuantizationMode::Affine;
|
||||
} else {
|
||||
} else if (mode == "mxfp4") {
|
||||
return QuantizationMode::Mxfp4;
|
||||
} else if (mode == "mxfp8") {
|
||||
return QuantizationMode::Mxfp8;
|
||||
} else if (mode == "nvfp4") {
|
||||
return QuantizationMode::Nvfp4;
|
||||
}
|
||||
std::string msg;
|
||||
if (!tag.empty()) {
|
||||
msg += "[" + std::string(tag) + "]";
|
||||
}
|
||||
msg += " Invalid quantization mode '" + mode + "'.";
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
|
||||
|
||||
@@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive {
|
||||
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
|
||||
};
|
||||
|
||||
enum class QuantizationMode { Affine, Mxfp4 };
|
||||
enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
|
||||
|
||||
std::string quantization_mode_to_string(QuantizationMode mode);
|
||||
QuantizationMode string_to_quantization_mode(const std::string& mode);
|
||||
QuantizationMode string_to_quantization_mode(
|
||||
const std::string& mode,
|
||||
std::string_view error_tag = "");
|
||||
|
||||
class Abs : public UnaryPrimitive {
|
||||
public:
|
||||
|
||||
@@ -61,13 +61,18 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
|
||||
|
||||
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
|
||||
|
||||
# Invalid output type
|
||||
with self.assertRaises(ValueError):
|
||||
mx.dequantize(
|
||||
w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
|
||||
)
|
||||
|
||||
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
|
||||
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user