use faster dequant for fp4 qmv (#2720)

Awni Hannun, 2025-10-31 11:49:59 -07:00 (committed by GitHub)
parent d9e6349657, commit 39b04ce638
3 changed files with 76 additions and 144 deletions

View File

@@ -49,7 +49,10 @@ struct fp4_e2m1 {
   }
   operator float() {
-    return FP4_LUT[bits];
+    half converted = as_type<half>(ushort((bits & 7) << 9));
+    converted *= 16384.0;
+    converted = bits & 8 ? -converted : converted;
+    return converted;
   }
   uint8_t bits;
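The new conversion above replaces the FP4_LUT table lookup with a pure bit manipulation: the three magnitude bits of an E2M1 code are placed at the top of a half-precision exponent/mantissa field, the exponent-bias mismatch is undone with a single multiply by 16384 (2^14), and the sign bit is applied last. Below is a minimal host-side C++ sketch, not part of the commit, that checks the trick against the usual E2M1 value table {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}; the table and the half-decoding helper are assumptions of this example rather than code from the repository.

```cpp
// Sanity-check sketch for the E2M1 bit trick (assumed value table, host-side).
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE binary16 bit pattern (normals and subnormals) to float.
static float half_bits_to_float(uint16_t h) {
  int exp = (h >> 10) & 0x1F;
  int man = h & 0x3FF;
  float mag = (exp == 0) ? std::ldexp(man / 1024.0f, -14)
                         : std::ldexp(1.0f + man / 1024.0f, exp - 15);
  return (h & 0x8000) ? -mag : mag;
}

int main() {
  // Assumed E2M1 decoding: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
  const float lut[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                         -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  for (int bits = 0; bits < 16; ++bits) {
    // Same steps as the kernel: magnitude bits into the top of a half,
    // multiply by 2^14 to fix the exponent bias, then apply the sign.
    float converted = half_bits_to_float(uint16_t((bits & 7) << 9)) * 16384.0f;
    converted = (bits & 8) ? -converted : converted;
    assert(converted == lut[bits]);
  }
  std::printf("E2M1 bit trick matches the value table\n");
  return 0;
}
```

Codes 0 and 1 land in the half subnormal range after the shift, and the multiply by 2^14 still recovers exact values there, so no special case is needed.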

View File

@@ -1,12 +1,5 @@
 #pragma once
-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline float fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
 struct fp8_e4m3 {
   template <typename T>
   fp8_e4m3(T f) {
@@ -14,7 +7,7 @@ struct fp8_e4m3 {
     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
     uint32_t fp8_max = 543 << 21;
     uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
     uint32_t sign = f_bits & 0x80000000;
     f_bits ^= sign;
     if (f_bits >= fp8_max) {
@@ -22,8 +15,8 @@ struct fp8_e4m3 {
       bits = 0x7E;
     } else {
       if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        f_bits = as_type<uint32_t>(
+            as_type<float>(f_bits) + as_type<float>(denorm_mask));
         bits = static_cast<uint8_t>(f_bits - denorm_mask);
       } else {
         // resulting mantissa is odd
@@ -53,7 +46,7 @@ struct fp8_e4m3 {
         ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
           inf_nan_mask) &
          ~zero_mask);
-    return fp32_from_bits(result);
+    return as_type<float>(result);
   }
   uint8_t bits;
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
     bits = static_cast<uint8_t>(n + 127);
   }
-  operator float() {
-    if (bits == 0xFF) {
-      return metal::numeric_limits<float>::quiet_NaN();
+  operator bfloat16_t() {
+    uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
+    return as_type<bfloat16_t>(out);
   }
-    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
+  operator float() {
+    return static_cast<float>(this->operator bfloat16_t());
   }
   uint8_t bits;
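In the E8M0 change just above, the scale byte is simply a biased exponent, so the new operator shifts it straight into the exponent field of a bfloat16: for a nonzero byte s, the pattern s << 7 decodes to 2^(s - 127), and s == 0 is special-cased to the subnormal pattern 0x40, which is also 2^-127. One observable difference is that 0xFF now decodes to +inf rather than the NaN the removed branch returned. Here is a small host-side C++ sketch, again an illustration rather than repository code, that checks the mapping by widening each bfloat16 pattern into a float:

```cpp
// Host-side check of the e8m0 -> bfloat16 mapping used above.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// bfloat16 is the top 16 bits of a binary32, so widen and reinterpret.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t w = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &w, sizeof(f));
  return f;
}

int main() {
  for (int bits = 0; bits < 255; ++bits) {  // 0xFF (inf) is checked separately
    uint16_t out = (bits == 0) ? 0x40 : uint16_t(bits) << 7;
    // E8M0 stores only a biased exponent: the scale is 2^(bits - 127).
    assert(bf16_bits_to_float(out) == std::ldexp(1.0f, bits - 127));
  }
  assert(std::isinf(bf16_bits_to_float(uint16_t(0xFF) << 7)));
  return 0;
}
```

This mapping is also why dequantize_scale in the quantized kernel file below can simply reinterpret the byte as fp8_e8m0 and cast, instead of assembling the bfloat16 pattern through a local union.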

View File

@@ -29,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
 template <typename T>
 static inline T dequantize_scale(uint8_t s) {
-  using FOrI = union {
-    bfloat16_t f;
-    uint16_t i;
-  };
-  FOrI out;
-  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
-  return static_cast<T>(out.f);
+  return T(*(thread fp8_e8m0*)(&s));
 }
+
+template <int bits>
+struct Quantize {
+  uint8_t operator()(float x) {
+    if (bits == 8) {
+      return fp8_e4m3(x).bits;
+    } else {
+      return fp4_e2m1(x).bits;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  float operator()(uint8_t x) {
+    if (bits == 8) {
+      return float(*(thread fp8_e4m3*)(&x));
+    } else {
+      return float(*(thread fp4_e2m1*)(&x));
+    }
+  }
+};
+
 template <typename T, typename U, int values_per_thread>
 inline void load_vector(const device T* x, thread U* x_thread) {
   for (int i = 0; i < values_per_thread; i += 4) {
@@ -62,62 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
   }
 }
 
-template <typename T>
-void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
-  if (simd_gid == 0 && simd_lid < 16) {
-    lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-}
-
 template <typename U, int values_per_thread>
-inline U qdot(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut) {
+inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
   U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (values_per_thread / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }
 
 template <typename U, int values_per_thread>
-inline U qdot_safe(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut,
-    int N) {
+inline U
+qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {
   U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (N / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }
 
 template <typename U, int values_per_thread>
-inline void qouter(
-    const thread uint8_t* w,
-    U x,
-    U scale,
-    thread U* result,
-    const threadgroup U* lut) {
+inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
   for (int i = 0; i < (values_per_thread / 2); i++) {
-    result[2 * i] += x * scale * lut[w[i] & 0xf];
-    result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
+    result[2 * i] += x * scale * Dequantize<4>{}(w[i]);
+    result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);
   }
 }
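With these changes, qdot, qdot_safe, and qouter no longer index a threadgroup-memory table by nibble; each nibble is dequantized arithmetically through Dequantize<4>, which lowers to the fp4_e2m1 bit trick from the first file, so the threadgroup lut argument, the load_fp4_lut staging step, and its barrier disappear from these helpers (the block loader further down still stages the table). Passing ws[i] >> 4 and friends directly is safe because fp4_e2m1 only reads the low four bits of the byte it receives. The following scalar C++ reference is a sketch of the nibble layout and accumulation that qdot assumes; the packed word, activations, and scale are made-up example values.

```cpp
// Reference sketch: four fp4 codes packed low-nibble first in one 16-bit word,
// dequantized per nibble and dotted with four activations, then scaled.
#include <cstdint>
#include <cstdio>

static float fp4_e2m1_to_float(uint8_t code) {
  // Assumed E2M1 value table; only the low 4 bits of the byte are meaningful.
  static const float lut[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                                -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  return lut[code & 0xF];
}

int main() {
  // Pack the codes for 1.0, 1.5, -2.0, 6.0 (0x2, 0x3, 0xC, 0x7), lowest nibble first.
  uint16_t w = 0x2 | (0x3 << 4) | (0xC << 8) | (0x7 << 12);
  float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float scale = 0.5f;

  // Same traversal as qdot: successive 4-bit fields of the 16-bit word.
  float accum = 0;
  accum += x[0] * fp4_e2m1_to_float(w);
  accum += x[1] * fp4_e2m1_to_float(w >> 4);
  accum += x[2] * fp4_e2m1_to_float(w >> 8);
  accum += x[3] * fp4_e2m1_to_float(w >> 12);

  // 0.5 * (1 + 3 - 6 + 24) = 11
  std::printf("qdot reference: %f\n", scale * accum);
  return 0;
}
```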
@@ -192,7 +187,10 @@ struct QuantizedBlockLoader {
             bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         lut(lut_) {
-    load_fp4_lut(lut, simd_group_id, simd_lane_id);
+    if (simd_group_id == 0 && simd_lane_id < 16) {
+      lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
   void load_unsafe() const {
@@ -264,10 +262,7 @@ METAL_FUNC void fp_qmv_quad_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
   constexpr int pack_factor = 8;
   constexpr int values_per_thread = D / QUAD_SIZE;
@@ -279,7 +274,6 @@ METAL_FUNC void fp_qmv_quad_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -299,7 +293,7 @@ METAL_FUNC void fp_qmv_quad_impl(
     U s = dequantize_scale<U>(sl[0]);
     if (row * quads_per_simd + out_row < out_vec_size) {
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
@@ -321,8 +315,7 @@ METAL_FUNC void fp_qmv_fast_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int packs_per_thread = 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
@@ -337,7 +330,6 @@ METAL_FUNC void fp_qmv_fast_impl(
   typedef float U;
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -358,7 +350,7 @@ METAL_FUNC void fp_qmv_fast_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -384,8 +376,7 @@ METAL_FUNC void fp_qmv_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
@@ -402,7 +393,6 @@ METAL_FUNC void fp_qmv_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -433,7 +423,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       uint8_t s = sl[0];
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
     ws += block_size * bytes_per_pack / pack_factor;
@@ -452,7 +442,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
      U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
@@ -481,7 +471,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -501,7 +491,7 @@ METAL_FUNC void fp_qmv_impl(
         U s = dequantize_scale<U>(sl[0]);
         result[row] +=
-            qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
+            qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
       }
     }
   for (int row = 0; row < results_per_simdgroup; row++) {
@@ -523,8 +513,7 @@ METAL_FUNC void fp_qvm_impl(
     const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int pack_factor = get_pack_factor<32>();
   constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -545,8 +534,6 @@ METAL_FUNC void fp_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;
-  load_fp4_lut(lut, simd_gid, simd_lid);
-
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
@@ -568,7 +555,7 @@ METAL_FUNC void fp_qvm_impl(
     scale = dequantize_scale<U>(*scales);
     w_local = *((device vec_w*)ws);
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
 
     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -581,7 +568,7 @@ METAL_FUNC void fp_qvm_impl(
     w_local = *((device vec_w*)ws);
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
 
     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -596,7 +583,7 @@ METAL_FUNC void fp_qvm_impl(
       scale = 0;
     }
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
   }
 
   // Accumulate in the simdgroup
@@ -975,9 +962,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
     const constant int64_t* s_strides,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   if (batched) {
     int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets(
@@ -995,20 +980,8 @@ template <typename T, int group_size, int bits, int D, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_quad_impl<T, group_size, bits, D>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size,
-      out_vec_size,
-      tid,
-      quad_gid,
-      quad_lid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
 }
 
 template <typename T, int group_size, int bits, bool batched>
@@ -1046,9 +1019,8 @@ template <typename T, int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, bool batched>
@@ -1086,9 +1058,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
        tid);
   }
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, bool batched>
@@ -1126,9 +1097,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, int split_k = 32>
@@ -1170,18 +1140,8 @@ template <typename T, const int group_size, int bits, int split_k = 32>
   int in_vec_size_adj =
       tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size_adj,
-      out_vec_size,
-      tid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <
@@ -1342,9 +1302,8 @@ template <typename T, int group_size, int bits>
       w_strides,
       s_strides,
      tid);
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, int group_size, int bits>
@@ -1392,9 +1351,8 @@ template <typename T, int group_size, int bits>
      w_strides,
      s_strides,
      tid);
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, int group_size, int bits>
@@ -1442,9 +1400,8 @@ template <typename T, int group_size, int bits>
      w_strides,
      s_strides,
      tid);
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <
@@ -1771,28 +1728,6 @@ template <
   }
 }
 
-template <int bits>
-struct Quantize {
-  uint8_t operator()(float x) {
-    if (bits == 8) {
-      return fp8_e4m3(x).bits;
-    } else {
-      return fp4_e2m1(x).bits;
-    }
-  }
-};
-
-template <int bits>
-struct Dequantize {
-  float operator()(uint8_t x) {
-    if (bits == 8) {
-      return float(*(thread fp8_e4m3*)(&x));
-    } else {
-      return float(*(thread fp4_e2m1*)(&x));
-    }
-  }
-};
-
 template <typename T, const int group_size, const int bits>
 [[kernel]] void fp_quantize(
     const device T* w [[buffer(0)]],