Fused Affine Quantize/Dequantize ops (#1282)

* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
This commit is contained in:
Alex Barron 2024-07-29 15:11:38 -07:00 committed by GitHub
parent aa1d6cadad
commit c52d1600f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 655 additions and 400 deletions

View File

@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,
@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,
@ -1099,7 +1099,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@ -1131,7 +1131,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@ -1382,7 +1382,7 @@ template <
s_strides,
b_strides,
tid);
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@ -1450,6 +1450,147 @@ template <
s_strides,
b_strides,
tid);
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
static_assert(
group_size % simd_size == 0,
"Group size must be divisible by simd size.");
int in_index = index * values_per_reduce;
int out_index = index * writes_per_pack;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
T w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
T val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
w_min = simd_min(w_min);
w_max = simd_max(w_max);
T scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
T edge = side ? w_min : w_max;
T q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
T bias = at_zero ? T(0) : edge;
// Write out the scales and biases
int gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = scale;
biases[gindex] = bias;
}
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
}
}
}
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
int in_index = index * packs_per_int;
int gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
}
}
out[index] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device T* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
int oindex = index * packs_per_int;
int gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[index];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[oindex + i] = scale * d + bias;
}
}

View File

@ -5,241 +5,67 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_qmv_fast(itype, group_size, bits) \
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
qmv_fast, \
itype, \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_qmv_fast_types(group_size, bits) \
instantiate_qmv_fast(float, group_size, bits) \
instantiate_qmv_fast(float16_t, group_size, bits) \
instantiate_qmv_fast(bfloat16_t, group_size, bits)
#define instantiate_quantized_types(name, group_size, bits) \
instantiate_quantized(name, float, group_size, bits) \
instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_quantized(name, bfloat16_t, group_size, bits)
instantiate_qmv_fast_types(128, 2)
instantiate_qmv_fast_types(128, 4)
instantiate_qmv_fast_types(128, 8)
instantiate_qmv_fast_types( 64, 2)
instantiate_qmv_fast_types( 64, 4)
instantiate_qmv_fast_types( 64, 8)
instantiate_qmv_fast_types( 32, 2)
instantiate_qmv_fast_types( 32, 4)
instantiate_qmv_fast_types( 32, 8)
#define instantiate_quantized_groups(name, bits) \
instantiate_quantized_types(name, 128, bits) \
instantiate_quantized_types(name, 64, bits) \
instantiate_quantized_types(name, 32, bits)
#define instantiate_qmv(itype, group_size, bits) \
instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits, \
qmv, \
itype, \
group_size, \
bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
#define instantiate_qmv_types(group_size, bits) \
instantiate_qmv(float, group_size, bits) \
instantiate_qmv(float16_t, group_size, bits) \
instantiate_qmv(bfloat16_t, group_size, bits)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)
#define instantiate_qvm(itype, group_size, bits) \
instantiate_kernel( \
"qvm_" #itype "_gs_" #group_size "_b_" #bits, \
qvm, \
itype, \
group_size, \
bits)
#define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, float, group_size, bits, false) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_qvm_types(group_size, bits) \
instantiate_qvm(float, group_size, bits) \
instantiate_qvm(float16_t, group_size, bits) \
instantiate_qvm(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups_aligned(name, bits) \
instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_quantized_types_aligned(name, 32, bits)
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_quantized_all_aligned(name) \
instantiate_quantized_groups_aligned(name, 2) \
instantiate_quantized_groups_aligned(name, 4) \
instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \
instantiate_kernel( \
"qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float, group_size, bits, false) \
instantiate_qmm_t(float16_t, group_size, bits, false) \
instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float, group_size, bits, true) \
instantiate_qmm_t(float16_t, group_size, bits, true) \
instantiate_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float, group_size, bits) \
instantiate_qmm_n(float16_t, group_size, bits) \
instantiate_qmm_n(bfloat16_t, group_size, bits)
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)
#define instantiate_bs_qmv_fast(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
bs_qmv_fast, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_fast_types(group_size, bits) \
instantiate_bs_qmv_fast(float, group_size, bits) \
instantiate_bs_qmv_fast(float16_t, group_size, bits) \
instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
instantiate_bs_qmv_fast_types(128, 2)
instantiate_bs_qmv_fast_types(128, 4)
instantiate_bs_qmv_fast_types(128, 8)
instantiate_bs_qmv_fast_types( 64, 2)
instantiate_bs_qmv_fast_types( 64, 4)
instantiate_bs_qmv_fast_types( 64, 8)
instantiate_bs_qmv_fast_types( 32, 2)
instantiate_bs_qmv_fast_types( 32, 4)
instantiate_bs_qmv_fast_types( 32, 8)
#define instantiate_bs_qmv(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmv, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_types(group_size, bits) \
instantiate_bs_qmv(float, group_size, bits) \
instantiate_bs_qmv(float16_t, group_size, bits) \
instantiate_bs_qmv(bfloat16_t, group_size, bits)
instantiate_bs_qmv_types(128, 2)
instantiate_bs_qmv_types(128, 4)
instantiate_bs_qmv_types(128, 8)
instantiate_bs_qmv_types( 64, 2)
instantiate_bs_qmv_types( 64, 4)
instantiate_bs_qmv_types( 64, 8)
instantiate_bs_qmv_types( 32, 2)
instantiate_bs_qmv_types( 32, 4)
instantiate_bs_qmv_types( 32, 8)
#define instantiate_bs_qvm(itype, group_size, bits) \
instantiate_kernel( \
"bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qvm, \
itype, \
group_size, \
bits)
#define instantiate_bs_qvm_types(group_size, bits) \
instantiate_bs_qvm(float, group_size, bits) \
instantiate_bs_qvm(float16_t, group_size, bits) \
instantiate_bs_qvm(bfloat16_t, group_size, bits)
instantiate_bs_qvm_types(128, 2)
instantiate_bs_qvm_types(128, 4)
instantiate_bs_qvm_types(128, 8)
instantiate_bs_qvm_types( 64, 2)
instantiate_bs_qvm_types( 64, 4)
instantiate_bs_qvm_types( 64, 8)
instantiate_bs_qvm_types( 32, 2)
instantiate_bs_qvm_types( 32, 4)
instantiate_bs_qvm_types( 32, 8)
#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
instantiate_kernel( \
"bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
bs_qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_bs_qmm_t_types(group_size, bits) \
instantiate_bs_qmm_t(float, group_size, bits, false) \
instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_bs_qmm_t(float, group_size, bits, true) \
instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_bs_qmm_t_types(128, 2)
instantiate_bs_qmm_t_types(128, 4)
instantiate_bs_qmm_t_types(128, 8)
instantiate_bs_qmm_t_types( 64, 2)
instantiate_bs_qmm_t_types( 64, 4)
instantiate_bs_qmm_t_types( 64, 8)
instantiate_bs_qmm_t_types( 32, 2)
instantiate_bs_qmm_t_types( 32, 4)
instantiate_bs_qmm_t_types( 32, 8)
#define instantiate_bs_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmm_n_types(group_size, bits) \
instantiate_bs_qmm_n(float, group_size, bits) \
instantiate_bs_qmm_n(float16_t, group_size, bits) \
instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
instantiate_bs_qmm_n_types(128, 2)
instantiate_bs_qmm_n_types(128, 4)
instantiate_bs_qmm_n_types(128, 8)
instantiate_bs_qmm_n_types( 64, 2)
instantiate_bs_qmm_n_types( 64, 4)
instantiate_bs_qmm_n_types( 64, 8)
instantiate_bs_qmm_n_types( 32, 2)
instantiate_bs_qmm_n_types( 32, 4)
instantiate_bs_qmm_n_types( 32, 8) // clang-format on
instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on

View File

@ -7,6 +7,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_fast";
kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@ -513,4 +514,82 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
if (!compute_scale_bias) {
auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_output_array(out, 3);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
compute_encoder.set_output_array(out, 1);
compute_encoder.set_output_array(scales, 2);
compute_encoder.set_output_array(biases, 3);
}
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = 8 / bits_;
int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
size_t nthreads =
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto group_dims = MTL::Size(thread_group_size, 1, 1);
auto grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@ -118,6 +118,7 @@ NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
} // namespace fast
} // namespace mlx::core

View File

@ -610,4 +610,253 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
}
array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int group_size,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return packed_w;
}
std::tuple<array, array, array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
int el_per_int = 32 / bits;
if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
wshape.back() = -1;
array zero(0, w.dtype());
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s),
};
};
std::vector<array> outputs;
if (s.device == Device::gpu) {
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) / el_per_int;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
outputs = array::make_arrays(
{wq_shape, sshape, sshape},
{uint32, w.dtype(), w.dtype()},
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w});
} else {
outputs = fallback({w});
}
return {outputs[0], outputs[1], outputs[2]};
}
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto scales = expand_dims(inputs[1], -1, s);
auto biases = expand_dims(inputs[2], -1, s);
auto wshape = w.shape();
wshape.back() = -1;
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {reshape(packed_w, wshape, s)};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) / el_per_int;
return array(
out_shape,
uint32,
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Packing into uint32
int el_per_int = 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback =
[&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
return {w_full};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) * el_per_int;
return array(
out_shape,
scales.dtype(),
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
} // namespace mlx::core::fast

View File

@ -39,4 +39,26 @@ array scaled_dot_product_attention(
const std::optional<array>& mask = std::nullopt,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
const array& w,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
} // namespace mlx::core::fast

View File

@ -212,4 +212,34 @@ class ScaledDotProductAttention : public Custom {
bool needs_mask_;
};
class AffineQuantize : public Custom {
public:
explicit AffineQuantize(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int group_size,
int bits,
bool dequantize)
: Custom(stream, fallback),
group_size_(group_size),
bits_(bits),
dequantize_(dequantize) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(AffineQuantize);
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int group_size_;
int bits_;
bool dequantize_;
};
} // namespace mlx::core::fast

View File

@ -6,6 +6,7 @@
#include <set>
#include <sstream>
#include "mlx/fast.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
@ -3356,89 +3357,7 @@ std::tuple<array, array, array> quantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
// Compute some constants used for the quantization
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array zero(0, w.dtype());
int el_per_int = 32 / bits;
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
// Check that the w matrix will fill up a whole SIMD.
// This is an implementation detail which should be removed in the future but
// at least we bail out early which will result in a nice readable error.
//
// Hopefully nobody is quantizing matrices that small anyway.
if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Prepare the shape for the outputs.
auto wshape = w.shape();
wshape.back() = -1;
// Compute scales and biases
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
// Quantize and pack w
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple(
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s));
return fast::affine_quantize(w, group_size, bits);
}
array dequantize(
@ -3448,76 +3367,7 @@ array dequantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Compute some constants for the dequantization
int el_per_int = 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
// Extract the pieces from the passed quantized matrix
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
return w_full;
return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
}
array gather_qmm(

View File

@ -2,6 +2,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include "mlx/fast.h"
@ -138,4 +139,47 @@ void init_fast(nb::module_& parent_module) {
Returns:
array: The output array.
)pbdoc");
m.def(
"affine_quantize",
nb::overload_cast<
const array&,
const array&,
const array&,
int,
int,
StreamOrDevice>(&fast::affine_quantize),
"w"_a,
"scales"_a,
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Quantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``group_size`` and ``bits`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
:math:`\beta` as follows
.. math::
w_i = s (\hat{w_i} + \beta)
Args:
w (array): Matrix to be quantize
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
Returns:
array: The quantized version of ``w``
)pbdoc");
}

View File

@ -439,6 +439,18 @@ class TestFast(mlx_tests.MLXTestCase):
)(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
def test_affine_quantize(self):
mx.random.seed(7)
x = mx.random.uniform(shape=(4, 1024))
for bits in (2, 4, 8):
for group_size in (32, 64, 128):
with self.subTest(bits=bits, group_size=group_size):
w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size)
w_p = mx.fast.affine_quantize(
x, scales, biases, bits=bits, group_size=group_size
)
self.assertTrue(mx.allclose(w, w_p))
if __name__ == "__main__":
unittest.main()

View File

@ -12,11 +12,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
w = mx.random.normal(shape=(128, 512))
for gs in [32, 64, 128]:
for b in [2, 4, 8]:
w_q, scales, biases = mx.quantize(w, gs, b)
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())
with self.subTest(gs=gs, b=b):
w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())
# test quantize/dequantize 0s
a = mx.zeros((256, 512))