Mirror of https://github.com/ml-explore/mlx.git
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize
* add full quantize kernel
* fused kernel with scale/bias computation
* fix docstring
* fix no jit error
* fix test
* test fix
* reduce fast api to only affine_quantize
Parent: aa1d6cadad
Commit: c52d1600f0
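For orientation before the diffs: the fused kernels implement per-group affine quantization, where each run of group_size consecutive weights shares one scale and one bias and values are rounded into 2**bits - 1 bins. Below is a minimal Python sketch of that math, assuming mlx is available; it mirrors the CPU fallback added in mlx/fast.cpp in this commit, not the Metal kernels themselves, and the uint32 packing step is omitted:

    import mlx.core as mx

    def affine_quantize_reference(w, group_size=64, bits=4):
        # One scale/bias per group of `group_size` elements along the last axis.
        n_bins = (1 << bits) - 1
        g = w.reshape(-1, w.shape[-1] // group_size, group_size)
        w_max = g.max(axis=-1, keepdims=True)
        w_min = g.min(axis=-1, keepdims=True)
        # Pick the edge with the larger magnitude so zero stays representable.
        mask = mx.abs(w_min) > mx.abs(w_max)
        scales = mx.maximum((w_max - w_min) / n_bins, 1e-7)
        scales = mx.where(mask, scales, -scales)
        edge = mx.where(mask, w_min, w_max)
        q0 = mx.round(edge / scales)
        scales = mx.where(q0 != 0, edge / q0, scales)
        biases = mx.where(q0 == 0, mx.zeros_like(edge), edge)
        # Quantized codes in [0, n_bins]; packing into uint32 words is omitted.
        q = mx.clip(mx.round((g - biases) / scales), 0, n_bins)
        return q, scales.squeeze(-1), biases.squeeze(-1)

The fused Metal kernels below compute the same per-group min/max with a simdgroup reduction instead of materializing these intermediate arrays.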
@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
     const int bits,
-    const bool aligned_N>
+    const bool aligned_N,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_t_impl(
     const device T* x,
     const device uint32_t* w,
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
-    const int bits>
+    const int bits,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_n_impl(
     const device T* x,
     const device uint32_t* w,
@@ -1099,7 +1099,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];

-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }

@@ -1131,7 +1131,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];

-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }

@@ -1382,7 +1382,7 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }

@@ -1450,6 +1450,147 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }

+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device T* scales [[buffer(2)]],
+    device T* biases [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr T eps = T(1e-7);
+  constexpr int simd_size = 32;
+  constexpr int uint8_bits = 8;
+  constexpr T n_bins = (1 << bits) - 1;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int values_per_reduce = group_size / simd_size;
+  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
+  constexpr int writes_per_pack =
+      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+
+  static_assert(
+      group_size % simd_size == 0,
+      "Group size must be divisible by simd size.");
+
+  int in_index = index * values_per_reduce;
+  int out_index = index * writes_per_pack;
+
+  T w_thread[values_per_reduce];
+  T w_min = Limits<T>::max;
+  T w_max = 0;
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    T val = w[in_index + i];
+    w_thread[i] = val;
+    w_min = min(w_min, val);
+    w_max = max(w_max, val);
+  }
+
+  w_min = simd_min(w_min);
+  w_max = simd_max(w_max);
+
+  T scale = max((w_max - w_min) / n_bins, eps);
+  bool side = abs(w_min) > abs(w_max);
+  scale = side ? scale : -scale;
+  T edge = side ? w_min : w_max;
+  T q0 = round(edge / scale);
+  bool at_zero = q0 == 0.0f;
+  scale = at_zero ? scale : edge / q0;
+  T bias = at_zero ? T(0) : edge;
+
+  // Write out the scales and biases
+  int gindex = in_index / group_size;
+  if (in_index % group_size == 0) {
+    scales[gindex] = scale;
+    biases[gindex] = bias;
+  }
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * (i % packs_per_int));
+    }
+
+    if (packs_per_int < values_per_reduce &&
+        i % packs_per_int == packs_per_int - 1) {
+      out[out_index + i / packs_per_int] = output;
+      output = 0;
+    } else {
+#pragma clang loop unroll(full)
+      for (int j = 0; j < writes_per_reduce - 1; j++) {
+        uint8_t sval = simd_shuffle_down(val, j + 1);
+        output += sval << (bits * (values_per_reduce + j + i));
+      }
+    }
+  }
+  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
+    out[out_index / writes_per_reduce] = output;
+  }
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize_scales_biases(
+    const device T* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device uint8_t* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr T n_bins = (1 << bits) - 1;
+
+  int in_index = index * packs_per_int;
+  int gindex = in_index / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * i);
+    }
+  }
+  out[index] = output;
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_dequantize(
+    const device uint8_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+
+  int oindex = index * packs_per_int;
+  int gindex = oindex / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+  uint val = w[index];
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t d;
+    if (bits == 2) {
+      d = (val >> (bits * i)) & 0x03;
+    } else if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[oindex + i] = scale * d + bias;
+  }
+}
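To make the byte packing in affine_dequantize concrete: with bits = 4, each uint8 holds packs_per_int = 8 / 4 = 2 codes, stored least-significant bits first, and the kernel recovers them with the shift-and-mask loop above. A plain-Python sketch of that unpack step (an illustrative helper, not part of the change):

    def unpack_codes(byte_val, bits):
        # Yield the quantized codes packed in one byte, least significant
        # first, mirroring the shift-and-mask loop in affine_dequantize.
        mask = (1 << bits) - 1
        for i in range(8 // bits):
            yield (byte_val >> (bits * i)) & mask

    print(list(unpack_codes(0xB7, 4)))  # [7, 11]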
@@ -5,241 +5,67 @@
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/quantized.h"

-#define instantiate_qmv_fast(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
-      qmv_fast, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_qmv_fast_types(group_size, bits) \
-  instantiate_qmv_fast(float, group_size, bits) \
-  instantiate_qmv_fast(float16_t, group_size, bits) \
-  instantiate_qmv_fast(bfloat16_t, group_size, bits)
-
-instantiate_qmv_fast_types(128, 2)
-instantiate_qmv_fast_types(128, 4)
-instantiate_qmv_fast_types(128, 8)
-instantiate_qmv_fast_types( 64, 2)
-instantiate_qmv_fast_types( 64, 4)
-instantiate_qmv_fast_types( 64, 8)
-instantiate_qmv_fast_types( 32, 2)
-instantiate_qmv_fast_types( 32, 4)
-instantiate_qmv_fast_types( 32, 8)
-
-#define instantiate_qmv(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmv_" #itype "_gs_" #group_size "_b_" #bits, \
-      qmv, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_qmv_types(group_size, bits) \
-  instantiate_qmv(float, group_size, bits) \
-  instantiate_qmv(float16_t, group_size, bits) \
-  instantiate_qmv(bfloat16_t, group_size, bits)
-
-instantiate_qmv_types(128, 2)
-instantiate_qmv_types(128, 4)
-instantiate_qmv_types(128, 8)
-instantiate_qmv_types( 64, 2)
-instantiate_qmv_types( 64, 4)
-instantiate_qmv_types( 64, 8)
-instantiate_qmv_types( 32, 2)
-instantiate_qmv_types( 32, 4)
-instantiate_qmv_types( 32, 8)
-
-#define instantiate_qvm(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qvm_" #itype "_gs_" #group_size "_b_" #bits, \
-      qvm, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_qvm_types(group_size, bits) \
-  instantiate_qvm(float, group_size, bits) \
-  instantiate_qvm(float16_t, group_size, bits) \
-  instantiate_qvm(bfloat16_t, group_size, bits)
-
-instantiate_qvm_types(128, 2)
-instantiate_qvm_types(128, 4)
-instantiate_qvm_types(128, 8)
-instantiate_qvm_types( 64, 2)
-instantiate_qvm_types( 64, 4)
-instantiate_qvm_types( 64, 8)
-instantiate_qvm_types( 32, 2)
-instantiate_qvm_types( 32, 4)
-instantiate_qvm_types( 32, 8)
-
-#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \
-  instantiate_kernel( \
-      "qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
-      qmm_t, \
-      itype, \
-      group_size, \
-      bits, \
-      aligned_N)
-
-#define instantiate_qmm_t_types(group_size, bits) \
-  instantiate_qmm_t(float, group_size, bits, false) \
-  instantiate_qmm_t(float16_t, group_size, bits, false) \
-  instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
-  instantiate_qmm_t(float, group_size, bits, true) \
-  instantiate_qmm_t(float16_t, group_size, bits, true) \
-  instantiate_qmm_t(bfloat16_t, group_size, bits, true)
-
-instantiate_qmm_t_types(128, 2)
-instantiate_qmm_t_types(128, 4)
-instantiate_qmm_t_types(128, 8)
-instantiate_qmm_t_types( 64, 2)
-instantiate_qmm_t_types( 64, 4)
-instantiate_qmm_t_types( 64, 8)
-instantiate_qmm_t_types( 32, 2)
-instantiate_qmm_t_types( 32, 4)
-instantiate_qmm_t_types( 32, 8)
-
-#define instantiate_qmm_n(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
-      qmm_n, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_qmm_n_types(group_size, bits) \
-  instantiate_qmm_n(float, group_size, bits) \
-  instantiate_qmm_n(float16_t, group_size, bits) \
-  instantiate_qmm_n(bfloat16_t, group_size, bits)
-
-instantiate_qmm_n_types(128, 2)
-instantiate_qmm_n_types(128, 4)
-instantiate_qmm_n_types(128, 8)
-instantiate_qmm_n_types( 64, 2)
-instantiate_qmm_n_types( 64, 4)
-instantiate_qmm_n_types( 64, 8)
-instantiate_qmm_n_types( 32, 2)
-instantiate_qmm_n_types( 32, 4)
-instantiate_qmm_n_types( 32, 8)
-
-#define instantiate_bs_qmv_fast(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
-      bs_qmv_fast, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmv_fast_types(group_size, bits) \
-  instantiate_bs_qmv_fast(float, group_size, bits) \
-  instantiate_bs_qmv_fast(float16_t, group_size, bits) \
-  instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmv_fast_types(128, 2)
-instantiate_bs_qmv_fast_types(128, 4)
-instantiate_bs_qmv_fast_types(128, 8)
-instantiate_bs_qmv_fast_types( 64, 2)
-instantiate_bs_qmv_fast_types( 64, 4)
-instantiate_bs_qmv_fast_types( 64, 8)
-instantiate_bs_qmv_fast_types( 32, 2)
-instantiate_bs_qmv_fast_types( 32, 4)
-instantiate_bs_qmv_fast_types( 32, 8)
-
-#define instantiate_bs_qmv(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qmv, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmv_types(group_size, bits) \
-  instantiate_bs_qmv(float, group_size, bits) \
-  instantiate_bs_qmv(float16_t, group_size, bits) \
-  instantiate_bs_qmv(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmv_types(128, 2)
-instantiate_bs_qmv_types(128, 4)
-instantiate_bs_qmv_types(128, 8)
-instantiate_bs_qmv_types( 64, 2)
-instantiate_bs_qmv_types( 64, 4)
-instantiate_bs_qmv_types( 64, 8)
-instantiate_bs_qmv_types( 32, 2)
-instantiate_bs_qmv_types( 32, 4)
-instantiate_bs_qmv_types( 32, 8)
-
-#define instantiate_bs_qvm(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qvm, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qvm_types(group_size, bits) \
-  instantiate_bs_qvm(float, group_size, bits) \
-  instantiate_bs_qvm(float16_t, group_size, bits) \
-  instantiate_bs_qvm(bfloat16_t, group_size, bits)
-
-instantiate_bs_qvm_types(128, 2)
-instantiate_bs_qvm_types(128, 4)
-instantiate_bs_qvm_types(128, 8)
-instantiate_bs_qvm_types( 64, 2)
-instantiate_bs_qvm_types( 64, 4)
-instantiate_bs_qvm_types( 64, 8)
-instantiate_bs_qvm_types( 32, 2)
-instantiate_bs_qvm_types( 32, 4)
-instantiate_bs_qvm_types( 32, 8)
-
-#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
-  instantiate_kernel( \
-      "bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
-      bs_qmm_t, \
-      itype, \
-      group_size, \
-      bits, \
-      aligned_N)
-
-#define instantiate_bs_qmm_t_types(group_size, bits) \
-  instantiate_bs_qmm_t(float, group_size, bits, false) \
-  instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
-  instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
-  instantiate_bs_qmm_t(float, group_size, bits, true) \
-  instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
-  instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
-
-instantiate_bs_qmm_t_types(128, 2)
-instantiate_bs_qmm_t_types(128, 4)
-instantiate_bs_qmm_t_types(128, 8)
-instantiate_bs_qmm_t_types( 64, 2)
-instantiate_bs_qmm_t_types( 64, 4)
-instantiate_bs_qmm_t_types( 64, 8)
-instantiate_bs_qmm_t_types( 32, 2)
-instantiate_bs_qmm_t_types( 32, 4)
-instantiate_bs_qmm_t_types( 32, 8)
-
-#define instantiate_bs_qmm_n(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qmm_n, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmm_n_types(group_size, bits) \
-  instantiate_bs_qmm_n(float, group_size, bits) \
-  instantiate_bs_qmm_n(float16_t, group_size, bits) \
-  instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmm_n_types(128, 2)
-instantiate_bs_qmm_n_types(128, 4)
-instantiate_bs_qmm_n_types(128, 8)
-instantiate_bs_qmm_n_types( 64, 2)
-instantiate_bs_qmm_n_types( 64, 4)
-instantiate_bs_qmm_n_types( 64, 8)
-instantiate_bs_qmm_n_types( 32, 2)
-instantiate_bs_qmm_n_types( 32, 4)
-instantiate_bs_qmm_n_types( 32, 8) // clang-format on
+#define instantiate_quantized(name, type, group_size, bits) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits, \
+      name, \
+      type, \
+      group_size, \
+      bits)
+
+#define instantiate_quantized_types(name, group_size, bits) \
+  instantiate_quantized(name, float, group_size, bits) \
+  instantiate_quantized(name, float16_t, group_size, bits) \
+  instantiate_quantized(name, bfloat16_t, group_size, bits)
+
+#define instantiate_quantized_groups(name, bits) \
+  instantiate_quantized_types(name, 128, bits) \
+  instantiate_quantized_types(name, 64, bits) \
+  instantiate_quantized_types(name, 32, bits)
+
+#define instantiate_quantized_all(name) \
+  instantiate_quantized_groups(name, 2) \
+  instantiate_quantized_groups(name, 4) \
+  instantiate_quantized_groups(name, 8)
+
+instantiate_quantized_all(qmv_fast)
+instantiate_quantized_all(qmv)
+instantiate_quantized_all(qvm)
+instantiate_quantized_all(qmm_n)
+instantiate_quantized_all(bs_qmv_fast)
+instantiate_quantized_all(bs_qmv)
+instantiate_quantized_all(bs_qvm)
+instantiate_quantized_all(bs_qmm_n)
+instantiate_quantized_all(affine_quantize)
+instantiate_quantized_all(affine_quantize_scales_biases)
+instantiate_quantized_all(affine_dequantize)
+
+#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      aligned)
+
+#define instantiate_quantized_types_aligned(name, group_size, bits) \
+  instantiate_quantized_aligned(name, float, group_size, bits, true) \
+  instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
+  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
+  instantiate_quantized_aligned(name, float, group_size, bits, false) \
+  instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
+  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
+
+#define instantiate_quantized_groups_aligned(name, bits) \
+  instantiate_quantized_types_aligned(name, 128, bits) \
+  instantiate_quantized_types_aligned(name, 64, bits) \
+  instantiate_quantized_types_aligned(name, 32, bits)
+
+#define instantiate_quantized_all_aligned(name) \
+  instantiate_quantized_groups_aligned(name, 2) \
+  instantiate_quantized_groups_aligned(name, 4) \
+  instantiate_quantized_groups_aligned(name, 8) \
+
+instantiate_quantized_all_aligned(qmm_t)
+instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
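The refactored macros expand to one kernel per (name, type, group size, bits) combination, and the generated kernel names are exactly the strings the host code assembles at dispatch time (see the kname changes below). A Python sketch of the names the non-aligned macros produce, for illustration only:

    kernels = [
        f"{name}_{t}_gs_{gs}_b_{b}"
        for name in (
            "qmv_fast", "qmv", "qvm", "qmm_n",
            "bs_qmv_fast", "bs_qmv", "bs_qvm", "bs_qmm_n",
            "affine_quantize", "affine_quantize_scales_biases",
            "affine_dequantize",
        )
        for t in ("float", "float16_t", "bfloat16_t")
        for gs in (128, 64, 32)
        for b in (2, 4, 8)
    ]
    print(kernels[0])  # qmv_fast_float_gs_128_b_2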
@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
+#include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"

 namespace mlx::core {
@@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
     auto type_string = get_type_string(x.dtype());
-    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
-          << "_fast";
+    kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
     auto type_string = get_type_string(x.dtype());
-    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_fast";
+    kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -513,4 +514,82 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }

+void fast::AffineQuantize::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  bool compute_scale_bias = inputs.size() == 1;
+
+  auto& w_pre = inputs[0];
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  std::vector<array> copies;
+  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto w = ensure_row_contiguous(w_pre);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_input_array(w, 0);
+  if (!compute_scale_bias) {
+    auto& scales_pre = inputs[1];
+    auto& biases_pre = inputs[2];
+    auto scales = ensure_row_contiguous(scales_pre);
+    auto biases = ensure_row_contiguous(biases_pre);
+    compute_encoder.set_input_array(scales, 1);
+    compute_encoder.set_input_array(biases, 2);
+    compute_encoder.set_output_array(out, 3);
+  } else {
+    auto& scales = outputs[1];
+    auto& biases = outputs[2];
+    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+    compute_encoder.set_output_array(out, 1);
+    compute_encoder.set_output_array(scales, 2);
+    compute_encoder.set_output_array(biases, 3);
+  }
+
+  std::ostringstream kname;
+  auto type_string = dequantize_ ? get_type_string(out.dtype())
+                                 : get_type_string(w_pre.dtype());
+  auto kernel_func = "affine_quantize_scales_biases";
+  if (dequantize_) {
+    kernel_func = "affine_dequantize";
+  } else if (compute_scale_bias) {
+    kernel_func = "affine_quantize";
+  }
+  kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
+        << bits_;
+  auto template_def = get_template_definition(
+      kname.str(), kernel_func, type_string, group_size_, bits_);
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Treat uint32 as uint8 in kernel
+  constexpr int uint8_per_uint32 = 4;
+  constexpr int simd_size = 32;
+  int packs_per_int = 8 / bits_;
+  int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
+  size_t nthreads =
+      dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
+
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  auto group_dims = MTL::Size(thread_group_size, 1, 1);
+  auto grid_dims = MTL::Size(nthreads, 1, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+}
+
 } // namespace mlx::core
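A worked example of the dispatch arithmetic in AffineQuantize::eval_gpu above, in plain Python and assuming group_size = 64, bits = 4, and a weight matrix with 4096 elements (illustrative values, not library code):

    simd_size = 32
    group_size, bits = 64, 4

    # Fused quantize path: each thread loads group_size / simd_size elements,
    # so one simdgroup covers exactly one quantization group.
    per_thread = group_size // simd_size       # 2
    nthreads_quantize = 4096 // per_thread     # 2048 threads for 4096 weights

    # Dequantize path: the packed uint32 words are viewed as uint8,
    # and one thread is launched per byte.
    uint8_per_uint32 = 4
    n_words = 4096 // (32 // bits)             # 512 packed uint32 words
    nthreads_dequantize = n_words * uint8_per_uint32   # 2048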
@@ -118,6 +118,7 @@ NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
+NO_GPU_MULTI(AffineQuantize)
 } // namespace fast

 } // namespace mlx::core
mlx/fast.cpp (249 lines changed)
@@ -610,4 +610,253 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
 }

+array pack_and_quantize(
+    array& packed_w,
+    const array& scales,
+    const array& biases,
+    int group_size,
+    int bits,
+    const Stream& s) {
+  int el_per_int = 32 / bits;
+  array zero(0, packed_w.dtype());
+  array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
+  array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
+  packed_w = astype(
+      clip(
+          round(divide(subtract(packed_w, biases, s), scales, s), s),
+          zero,
+          n_bins),
+      uint32);
+  packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
+  packed_w = sum(
+      multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
+  return packed_w;
+}
+
+std::tuple<array, array, array>
+affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
+  auto s = to_stream(s_);
+
+  if (group_size != 32 && group_size != 64 && group_size != 128) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested group size " << group_size
+        << " is not supported. The supported group sizes are 32, 64 and 128.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (bits != 2 && bits != 4 && bits != 8) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested number of bits " << bits
+        << " is not supported. The supported bits are 2, 4 and 8.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[quantize] The matrix to be quantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if ((w.shape(-1) % group_size) != 0) {
+    std::ostringstream msg;
+    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
+        << "the quantization group size " << group_size
+        << ". However the provided matrix has shape " << w.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  int el_per_int = 32 / bits;
+
+  if (w.shape(-1) < 32 * el_per_int) {
+    std::ostringstream msg;
+    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
+        << "too small for quantization. We support >= 512 for 2 bits, "
+        << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
+        << "shape " << w.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto fallback = [group_size, bits, el_per_int, s](
+                      const std::vector<array>& inputs) -> std::vector<array> {
+    auto& w = inputs[0];
+    auto wshape = w.shape();
+    wshape.back() = -1;
+
+    array zero(0, w.dtype());
+    array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
+    array eps(1e-7, w.dtype());
+
+    array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
+
+    array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
+    array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
+    array mask = greater(abs(w_min, s), abs(w_max, s), s);
+    array scales =
+        maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
+    scales = where(mask, scales, negative(scales), s);
+    array edge = where(mask, w_min, w_max, s);
+    array q0 = round(divide(edge, scales, s), s);
+    scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
+    array biases = where(equal(q0, zero, s), zero, edge);
+
+    packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
+    return {
+        reshape(packed_w, wshape, s),
+        reshape(scales, wshape, s),
+        reshape(biases, wshape, s),
+    };
+  };
+
+  std::vector<array> outputs;
+  if (s.device == Device::gpu) {
+    auto wq_shape = w.shape();
+    wq_shape.back() = w.shape(-1) / el_per_int;
+    auto sshape = w.shape();
+    sshape.back() = w.shape(-1) / group_size;
+    outputs = array::make_arrays(
+        {wq_shape, sshape, sshape},
+        {uint32, w.dtype(), w.dtype()},
+        std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
+        {w});
+  } else {
+    outputs = fallback({w});
+  }
+  return {outputs[0], outputs[1], outputs[2]};
+}
+
+array affine_quantize(
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int group_size,
+    int bits,
+    StreamOrDevice s_) {
+  auto s = to_stream(s_);
+
+  int el_per_int = 32 / bits;
+  auto fallback = [group_size, bits, el_per_int, s](
+                      const std::vector<array>& inputs) -> std::vector<array> {
+    auto& w = inputs[0];
+    auto scales = expand_dims(inputs[1], -1, s);
+    auto biases = expand_dims(inputs[2], -1, s);
+
+    auto wshape = w.shape();
+    wshape.back() = -1;
+
+    array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
+    packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
+    return {reshape(packed_w, wshape, s)};
+  };
+
+  if (s.device == Device::gpu) {
+    auto out_shape = w.shape();
+    out_shape.back() = w.shape(-1) / el_per_int;
+    return array(
+        out_shape,
+        uint32,
+        std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
+        {w, scales, biases});
+  }
+  return fallback({w, scales, biases})[0];
+}
+
+array affine_dequantize(
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int group_size,
+    int bits,
+    StreamOrDevice s_) {
+  if (bits <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for bits: " << bits;
+    throw std::invalid_argument(msg.str());
+  }
+  if (group_size <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for group_size: " << group_size;
+    throw std::invalid_argument(msg.str());
+  }
+  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto wshape = w.shape();
+  auto sshape = scales.shape();
+  auto bshape = biases.shape();
+  wshape.back() = -1;
+  sshape.back() = -1;
+  bshape.back() = -1;
+
+  if (wshape != sshape || wshape != bshape) {
+    throw std::invalid_argument(
+        "[dequantize] Shape of scales and biases does not match the matrix");
+  }
+
+  if (w.dtype() != uint32) {
+    throw std::invalid_argument(
+        "[dequantize] The matrix should be given as a uint32");
+  }
+
+  // Packing into uint32
+  int el_per_int = 32 / bits;
+
+  if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
+    std::ostringstream msg;
+    msg << "[dequantize] Shape of scales and biases does not match the matrix "
+        << "given the quantization parameters. Provided matrix of shape "
+        << w.shape() << " and scales/biases of shape " << scales.shape()
+        << " with group_size=" << group_size << " and bits=" << bits << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto s = to_stream(s_);
+
+  auto fallback =
+      [&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
+          const std::vector<array>& inputs) -> std::vector<array> {
+    auto& w = inputs[0];
+    auto& scales = inputs[1];
+    auto& biases = inputs[2];
+    std::vector<array> parts;
+    for (int start = 0; start < 32; start += bits) {
+      int shift_left = 32 - (start + bits);
+      int shift_right = shift_left + start;
+
+      parts.push_back(expand_dims(
+          right_shift(
+              left_shift(w, array(32 - (start + bits), uint32), s),
+              array(32 - bits, uint32),
+              s),
+          -1,
+          s));
+    }
+    array w_full = concatenate(parts, -1, s);
+
+    // Dequantize
+    wshape.push_back(group_size);
+    w_full = reshape(w_full, wshape, s);
+    w_full = multiply(w_full, expand_dims(scales, -1, s), s);
+    w_full = add(w_full, expand_dims(biases, -1, s), s);
+    w_full = reshape(w_full, sshape, s);
+
+    return {w_full};
+  };
+
+  if (s.device == Device::gpu) {
+    auto out_shape = w.shape();
+    out_shape.back() = w.shape(-1) * el_per_int;
+    return array(
+        out_shape,
+        scales.dtype(),
+        std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
+        {w, scales, biases});
+  }
+  return fallback({w, scales, biases})[0];
+}
+
 } // namespace mlx::core::fast
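To illustrate the shift-and-sum packing that pack_and_quantize performs, here is the same computation in plain Python for bits = 4, i.e. eight codes per uint32 word (illustrative values):

    bits = 4
    el_per_int = 32 // bits                     # 8 codes per uint32
    shifts = [1 << (bits * i) for i in range(el_per_int)]   # 2**(0, 4, ..., 28)

    codes = [7, 11, 0, 15, 3, 3, 1, 2]          # quantized codes in [0, 15]
    packed = sum(c * s for c, s in zip(codes, shifts))      # one uint32 word

    # Round trip with the shift/mask unpacking used by dequantize:
    unpacked = [(packed >> (bits * i)) & 0xF for i in range(el_per_int)]
    assert unpacked == codes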
mlx/fast.h (22 lines changed)
@@ -39,4 +39,26 @@ array scaled_dot_product_attention(
     const std::optional<array>& mask = std::nullopt,
     StreamOrDevice s = {});

+std::tuple<array, array, array> affine_quantize(
+    const array& w,
+    int group_size = 64,
+    int bits = 4,
+    StreamOrDevice s = {});
+
+array affine_quantize(
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int group_size = 64,
+    int bits = 4,
+    StreamOrDevice s = {});
+
+array affine_dequantize(
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int group_size = 64,
+    int bits = 4,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::fast
@@ -212,4 +212,34 @@ class ScaledDotProductAttention : public Custom {
   bool needs_mask_;
 };

+class AffineQuantize : public Custom {
+ public:
+  explicit AffineQuantize(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int group_size,
+      int bits,
+      bool dequantize)
+      : Custom(stream, fallback),
+        group_size_(group_size),
+        bits_(bits),
+        dequantize_(dequantize) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(AffineQuantize);
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  int group_size_;
+  int bits_;
+  bool dequantize_;
+};
+
 } // namespace mlx::core::fast
mlx/ops.cpp (156 lines changed)
@@ -6,6 +6,7 @@
 #include <set>
 #include <sstream>

+#include "mlx/fast.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
@@ -3356,89 +3357,7 @@ std::tuple<array, array, array> quantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  if (group_size != 32 && group_size != 64 && group_size != 128) {
-    std::ostringstream msg;
-    msg << "[quantize] The requested group size " << group_size
-        << " is not supported. The supported group sizes are 64 and 128.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (bits != 2 && bits != 4 && bits != 8) {
-    std::ostringstream msg;
-    msg << "[quantize] The requested number of bits " << bits
-        << " is not supported. The supported bits are 2, 4 and 8.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (w.ndim() < 2) {
-    std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if ((w.shape(-1) % group_size) != 0) {
-    std::ostringstream msg;
-    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
-        << "the quantization group size " << group_size
-        << ". However the provided " << " matrix has shape " << w.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
-  // Compute some constants used for the quantization
-  array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
-  array eps(1e-7, w.dtype());
-  array zero(0, w.dtype());
-  int el_per_int = 32 / bits;
-  array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
-  shifts = reshape(shifts, {1, 1, -1}, s);
-
-  // Check that the w matrix will fill up a whole SIMD.
-  // This is an implementation detail which should be removed in the future but
-  // at least we bail out early which will result in a nice readable error.
-  //
-  // Hopefully nobody is quantizing matrices that small anyway.
-  if (w.shape(-1) < 32 * el_per_int) {
-    std::ostringstream msg;
-    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
-        << "too small for quantization. We support >=512 for 2 bits, "
-        << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
-        << "shape " << w.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
-  // Prepare the shape for the outputs.
-  auto wshape = w.shape();
-  wshape.back() = -1;
-
-  // Compute scales and biases
-  array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
-  array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
-  array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
-
-  array mask = greater(abs(w_min, s), abs(w_max, s), s);
-  array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
-  scales = where(mask, scales, negative(scales), s);
-  array edge = where(mask, w_min, w_max, s);
-  array q0 = round(divide(edge, scales, s), s);
-  scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
-  array biases = where(equal(q0, zero, s), zero, edge);
-
-  // Quantize and pack w
-  packed_w = astype(
-      clip(
-          round(divide(subtract(packed_w, biases, s), scales, s), s),
-          zero,
-          n_bins),
-      uint32);
-  packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
-  packed_w = sum(
-      multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
-
-  return std::make_tuple(
-      reshape(packed_w, wshape, s),
-      reshape(scales, wshape, s),
-      reshape(biases, wshape, s));
+  return fast::affine_quantize(w, group_size, bits);
 }

 array dequantize(
@@ -3448,76 +3367,7 @@ array dequantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  if (bits <= 0) {
-    std::ostringstream msg;
-    msg << "[dequantize] Invalid value for bits: " << bits;
-    throw std::invalid_argument(msg.str());
-  }
-  if (group_size <= 0) {
-    std::ostringstream msg;
-    msg << "[dequantize] Invalid value for group_size: " << group_size;
-    throw std::invalid_argument(msg.str());
-  }
-  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
-    std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
-  auto wshape = w.shape();
-  auto sshape = scales.shape();
-  auto bshape = biases.shape();
-  wshape.back() = -1;
-  sshape.back() = -1;
-  bshape.back() = -1;
-
-  if (wshape != sshape || wshape != bshape) {
-    throw std::invalid_argument(
-        "[dequantize] Shape of scales and biases does not match the matrix");
-  }
-
-  if (w.dtype() != uint32) {
-    throw std::invalid_argument(
-        "[dequantize] The matrix should be given as a uint32");
-  }
-
-  // Compute some constants for the dequantization
-  int el_per_int = 32 / bits;
-
-  if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
-    std::ostringstream msg;
-    msg << "[dequantize] Shape of scales and biases does not match the matrix "
-        << "given the quantization parameters. Provided matrix of shape "
-        << w.shape() << " and scales/biases of shape " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
-  // Extract the pieces from the passed quantized matrix
-  std::vector<array> parts;
-  for (int start = 0; start < 32; start += bits) {
-    int shift_left = 32 - (start + bits);
-    int shift_right = shift_left + start;
-
-    parts.push_back(expand_dims(
-        right_shift(
-            left_shift(w, array(32 - (start + bits), uint32), s),
-            array(32 - bits, uint32),
-            s),
-        -1,
-        s));
-  }
-  array w_full = concatenate(parts, -1, s);
-
-  // Dequantize
-  wshape.push_back(group_size);
-  w_full = reshape(w_full, wshape, s);
-  w_full = multiply(w_full, expand_dims(scales, -1, s), s);
-  w_full = add(w_full, expand_dims(biases, -1, s), s);
-  w_full = reshape(w_full, sshape, s);
-
-  return w_full;
+  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
 }

 array gather_qmm(
@@ -2,6 +2,7 @@
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
+#include <nanobind/stl/tuple.h>
 #include <nanobind/stl/variant.h>

 #include "mlx/fast.h"
@@ -138,4 +139,47 @@ void init_fast(nb::module_& parent_module) {
       Returns:
           array: The output array.
       )pbdoc");
+
+  m.def(
+      "affine_quantize",
+      nb::overload_cast<
+          const array&,
+          const array&,
+          const array&,
+          int,
+          int,
+          StreamOrDevice>(&fast::affine_quantize),
+      "w"_a,
+      "scales"_a,
+      "biases"_a,
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Quantize the matrix ``w`` using the provided ``scales`` and
+        ``biases`` and the ``group_size`` and ``bits`` configuration.
+
+        Formally, given the notation in :func:`quantize`, we compute
+        :math:`\hat{w_i}` from :math:`w_i` and the corresponding :math:`s` and
+        :math:`\beta` as follows
+
+        .. math::
+
+          \hat{w_i} = \textrm{round}\left( \frac{w_i - \beta}{s} \right)
+
+        Args:
+            w (array): Matrix to be quantized
+            scales (array): The scales to use per ``group_size`` elements of ``w``
+            biases (array): The biases to use per ``group_size`` elements of ``w``
+            group_size (int, optional): The size of the group in ``w`` that shares a
+              scale and bias. (default: ``64``)
+            bits (int, optional): The number of bits occupied by each element in
+              ``w``. (default: ``4``)
+
+        Returns:
+            array: The quantized version of ``w``
+      )pbdoc");
 }
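A short usage example of the new binding, matching the signature above and the test added below (this assumes a build of mlx where either the Metal fast path or the fallback is available):

    import mlx.core as mx

    w = mx.random.uniform(shape=(4, 1024))
    # Compute scales/biases with the regular quantize op...
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    # ...then reproduce the packed weights with the fused fast op.
    w_q_fast = mx.fast.affine_quantize(w, scales, biases, group_size=64, bits=4)
    assert mx.allclose(w_q, w_q_fast)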
@@ -439,6 +439,18 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))

+    def test_affine_quantize(self):
+        mx.random.seed(7)
+        x = mx.random.uniform(shape=(4, 1024))
+        for bits in (2, 4, 8):
+            for group_size in (32, 64, 128):
+                with self.subTest(bits=bits, group_size=group_size):
+                    w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size)
+                    w_p = mx.fast.affine_quantize(
+                        x, scales, biases, bits=bits, group_size=group_size
+                    )
+                    self.assertTrue(mx.allclose(w, w_p))
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -12,7 +12,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w = mx.random.normal(shape=(128, 512))
         for gs in [32, 64, 128]:
             for b in [2, 4, 8]:
-                w_q, scales, biases = mx.quantize(w, gs, b)
+                with self.subTest(gs=gs, b=b):
+                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                 w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                 errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                 eps = 1e-6