Fused Affine Quantize/Dequantize ops (#1282)

* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
This commit is contained in:
Alex Barron
2024-07-29 15:11:38 -07:00
committed by GitHub
parent aa1d6cadad
commit c52d1600f0
11 changed files with 655 additions and 400 deletions

View File

@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,
@@ -1099,7 +1099,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1131,7 +1131,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1382,7 +1382,7 @@ template <
s_strides,
b_strides,
tid);
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1450,6 +1450,147 @@ template <
s_strides,
b_strides,
tid);
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
static_assert(
group_size % simd_size == 0,
"Group size must be divisible by simd size.");
int in_index = index * values_per_reduce;
int out_index = index * writes_per_pack;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
T w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
T val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
w_min = simd_min(w_min);
w_max = simd_max(w_max);
T scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
T edge = side ? w_min : w_max;
T q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
T bias = at_zero ? T(0) : edge;
// Write out the scales and biases
int gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = scale;
biases[gindex] = bias;
}
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
}
}
}
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
int in_index = index * packs_per_int;
int gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
}
}
out[index] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device T* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
int oindex = index * packs_per_int;
int gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[index];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[oindex + i] = scale * d + bias;
}
}