Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize
* add full quantize kernel
* fused kernel with scale/bias computation
* fix docstring
* fix no jit error
* fix test
* test fix
* reduce fast api to only affine_quantize
@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
 
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
     const int bits,
-    const bool aligned_N>
+    const bool aligned_N,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_t_impl(
     const device T* x,
     const device uint32_t* w,
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
 
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
-    const int bits>
+    const int bits,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_n_impl(
     const device T* x,
     const device uint32_t* w,
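
The reorder above is what lets the call sites below drop the tile sizes: BM, BK, and BN now trail the required parameters and default to 32, so an instantiation only has to spell out group_size, bits, and (for qmm_t_impl) aligned_N. A minimal standalone C++ sketch of the pattern, for illustration only (the function name and empty body are placeholders, not MLX code; only the parameter ordering mirrors the diff):

#include <cstdint>

// Required parameters first, tile sizes trailing with defaults,
// so callers can omit BM/BK/BN for the common 32x32x32 configuration.
template <
    typename T,
    int group_size,
    int bits,
    bool aligned_N,
    int BM = 32,
    int BK = 32,
    int BN = 32>
void qmm_t_like(const T* /*x*/, const uint32_t* /*w*/) {
  // Tiled matmul body elided; only the template ordering matters here.
}

int main() {
  // With the old ordering every caller had to pass BM, BK, BN explicitly;
  // with trailing defaults the common case needs only this:
  qmm_t_like<float, 64, 4, true>(nullptr, nullptr);
}
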
@@ -1099,7 +1099,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];
 
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
@@ -1131,7 +1131,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];
 
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
@@ -1382,7 +1382,7 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
@@ -1450,6 +1450,147 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device T* scales [[buffer(2)]],
+    device T* biases [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr T eps = T(1e-7);
+  constexpr int simd_size = 32;
+  constexpr int uint8_bits = 8;
+  constexpr T n_bins = (1 << bits) - 1;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int values_per_reduce = group_size / simd_size;
+  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
+  constexpr int writes_per_pack =
+      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+
+  static_assert(
+      group_size % simd_size == 0,
+      "Group size must be divisible by simd size.");
+
+  int in_index = index * values_per_reduce;
+  int out_index = index * writes_per_pack;
+
+  T w_thread[values_per_reduce];
+  T w_min = Limits<T>::max;
+  T w_max = 0;
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    T val = w[in_index + i];
+    w_thread[i] = val;
+    w_min = min(w_min, val);
+    w_max = max(w_max, val);
+  }
+
+  w_min = simd_min(w_min);
+  w_max = simd_max(w_max);
+
+  T scale = max((w_max - w_min) / n_bins, eps);
+  bool side = abs(w_min) > abs(w_max);
+  scale = side ? scale : -scale;
+  T edge = side ? w_min : w_max;
+  T q0 = round(edge / scale);
+  bool at_zero = q0 == 0.0f;
+  scale = at_zero ? scale : edge / q0;
+  T bias = at_zero ? T(0) : edge;
+
+  // Write out the scales and biases
+  int gindex = in_index / group_size;
+  if (in_index % group_size == 0) {
+    scales[gindex] = scale;
+    biases[gindex] = bias;
+  }
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * (i % packs_per_int));
+    }
+
+    if (packs_per_int < values_per_reduce &&
+        i % packs_per_int == packs_per_int - 1) {
+      out[out_index + i / packs_per_int] = output;
+      output = 0;
+    } else {
+#pragma clang loop unroll(full)
+      for (int j = 0; j < writes_per_reduce - 1; j++) {
+        uint8_t sval = simd_shuffle_down(val, j + 1);
+        output += sval << (bits * (values_per_reduce + j + i));
+      }
+    }
+  }
+  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
+    out[out_index / writes_per_reduce] = output;
+  }
+}
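
For reference, the per-group math in affine_quantize reads as the scalar sketch below: take the group min/max, derive a scale that spans the range over 2^bits - 1 bins, snap the scale so the larger-magnitude edge maps exactly onto an integer code (edge / q0), then round each value to a code and pack 8 / bits codes per output byte. This is plain host-side C++ for illustration, not MLX code; the helper name affine_quantize_group is made up, and the kernel itself performs the min/max with simd_min/simd_max across a simdgroup and packs across threads with simd_shuffle_down.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar sketch of one group's affine quantization (mirrors the kernel math).
template <int group_size, int bits>
void affine_quantize_group(
    const float* w, uint8_t* out, float& scale, float& bias) {
  constexpr float eps = 1e-7f;
  constexpr float n_bins = (1 << bits) - 1; // e.g. 15 for 4-bit
  constexpr int packs_per_byte = 8 / bits;

  // Group range.
  float w_min = w[0], w_max = w[0];
  for (int i = 1; i < group_size; i++) {
    w_min = std::min(w_min, w[i]);
    w_max = std::max(w_max, w[i]);
  }

  // Scale/bias selection, including the edge-snapping step from the kernel.
  scale = std::max((w_max - w_min) / n_bins, eps);
  bool side = std::abs(w_min) > std::abs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = std::round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  bias = at_zero ? 0.0f : edge;

  // Round to codes and pack 8 / bits codes per byte, low bits first.
  for (int i = 0; i < group_size; i++) {
    float q = std::round((w[i] - bias) / scale);
    uint8_t code = (uint8_t)std::min(std::max(q, 0.0f), n_bins);
    int byte = i / packs_per_byte;
    int slot = i % packs_per_byte;
    if (slot == 0)
      out[byte] = 0;
    out[byte] |= code << (bits * slot);
  }
}

int main() {
  std::vector<float> w(64);
  for (int i = 0; i < 64; i++)
    w[i] = 0.01f * i - 0.2f; // arbitrary sample weights
  std::vector<uint8_t> packed(64 / 2); // 4-bit: two codes per byte
  float scale, bias;
  affine_quantize_group<64, 4>(w.data(), packed.data(), scale, bias);
  std::printf("scale=%f bias=%f first byte=0x%02x\n", scale, bias, packed[0]);
}
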
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize_scales_biases(
+    const device T* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device uint8_t* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr T n_bins = (1 << bits) - 1;
+
+  int in_index = index * packs_per_int;
+  int gindex = in_index / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * i);
+    }
+  }
+  out[index] = output;
+}
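
affine_quantize_scales_biases handles the case where the scales and biases are already available as inputs, so each thread only rounds its packs_per_int values against the group's scale/bias and packs them into one byte, low bits first. A small standalone C++ example of that packing order (the code values are arbitrary):

#include <cstdint>
#include <cstdio>

int main() {
  // bits = 4, so packs_per_int = 8 / bits = 2 codes share one output byte.
  const int bits = 4;
  const int packs_per_int = 8 / bits;
  uint8_t codes[2] = {3, 12}; // already quantized: round((w - bias) / scale)
  uint8_t output = 0;
  for (int i = 0; i < packs_per_int; i++)
    output |= codes[i] << (bits * i); // code i occupies bits 4*i .. 4*i+3
  std::printf("packed byte = 0x%02x\n", output); // 0xc3
}
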
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_dequantize(
+    const device uint8_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+
+  int oindex = index * packs_per_int;
+  int gindex = oindex / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+  uint val = w[index];
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t d;
+    if (bits == 2) {
+      d = (val >> (bits * i)) & 0x03;
+    } else if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[oindex + i] = scale * d + bias;
+  }
+}
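
affine_dequantize is the inverse: each input byte holds 8 / bits codes, and every code d is mapped back through scale * d + bias using its group's stored scale and bias. A scalar C++ sketch, for illustration only (the helper name and the sample scale/bias values are made up; the kernel assigns one packed byte per thread):

#include <cstdint>
#include <cstdio>

// Unpack `count` codes from a packed buffer and apply the affine map.
template <int bits>
void affine_dequantize_group(
    const uint8_t* packed, int count, float scale, float bias, float* out) {
  constexpr int packs_per_int = 8 / bits;
  constexpr uint8_t mask = (1u << bits) - 1; // 0x03, 0x0f, or 0xff
  for (int i = 0; i < count; i++) {
    uint8_t byte = packed[i / packs_per_int];
    uint8_t d = (byte >> (bits * (i % packs_per_int))) & mask;
    out[i] = scale * d + bias;
  }
}

int main() {
  uint8_t packed[2] = {0xc3, 0x1f}; // two bytes of 4-bit codes
  float out[4];
  affine_dequantize_group<4>(packed, 4, 0.1f, -0.2f, out);
  for (float v : out)
    std::printf("%.2f ", v); // 0.10 1.00 1.30 -0.10
}
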