use faster dequant for fp4 qmv (#2720)

Author: Awni Hannun
Date: 2025-10-31 11:49:59 -07:00
Committed by: GitHub
Parent: d9e6349657
Commit: 39b04ce638
3 changed files with 76 additions and 144 deletions

Changed file 1 of 3:

@@ -49,7 +49,10 @@ struct fp4_e2m1 {
   }

   operator float() {
-    return FP4_LUT[bits];
+    half converted = as_type<half>(ushort((bits & 7) << 9));
+    converted *= 16384.0;
+    converted = bits & 8 ? -converted : converted;
+    return converted;
   }

   uint8_t bits;
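
The new conversion skips the FP4_LUT lookup by building a half directly from the fp4 bit pattern: the two exponent bits and one mantissa bit are shifted into the low exponent bits and top mantissa bit of a half, and the multiply by 16384 (2^14) removes the bias difference between e2m1 (bias 1) and fp16 (bias 15); the e2m1 subnormal (code 001 -> 0.5) falls out as a half subnormal. A minimal host-side C++ sketch of the same trick, with float standing in for half and the standard 16-entry e2m1 value table assumed, checks it against the LUT:

// Host-side analogue of the Metal trick (float stands in for half): put the
// fp4 e2m1 exponent in the low bits of the float exponent field and the single
// mantissa bit at the top of the float mantissa, then rescale by a power of
// two for the bias difference (e2m1 bias 1 vs. fp32 bias 127).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Assumed e2m1 value table (sign in bit 3), matching the removed FP4_LUT.
static const float kFp4Lut[16] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                                  -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

float dequant_fp4(uint8_t bits) {
  uint32_t u = uint32_t(bits & 7) << 22;  // exponent -> bits 24:23, mantissa -> bit 22
  float converted;
  std::memcpy(&converted, &u, sizeof converted);
  converted *= std::ldexp(1.0f, 126);     // rebias; same role as the * 16384.0 for half
  return (bits & 8) ? -converted : converted;
}

int main() {
  for (int code = 0; code < 16; ++code) {
    std::printf("%2d: %g %s\n", code, dequant_fp4(uint8_t(code)),
                dequant_fp4(uint8_t(code)) == kFp4Lut[code] ? "ok" : "MISMATCH");
  }
}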

Changed file 2 of 3:

@@ -1,12 +1,5 @@
 #pragma once

-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline float fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
-
 struct fp8_e4m3 {
   template <typename T>
   fp8_e4m3(T f) {
@@ -14,7 +7,7 @@ struct fp8_e4m3 {
     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
     uint32_t fp8_max = 543 << 21;
     uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
     uint32_t sign = f_bits & 0x80000000;
     f_bits ^= sign;
     if (f_bits >= fp8_max) {
@@ -22,8 +15,8 @@ struct fp8_e4m3 {
       bits = 0x7E;
     } else {
       if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        f_bits = as_type<uint32_t>(
+            as_type<float>(f_bits) + as_type<float>(denorm_mask));
         bits = static_cast<uint8_t>(f_bits - denorm_mask);
       } else {
         // resulting mantissa is odd
@@ -53,7 +46,7 @@ struct fp8_e4m3 {
         ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
           inf_nan_mask) &
          ~zero_mask);
-    return fp32_from_bits(result);
+    return as_type<float>(result);
   }

   uint8_t bits;
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
     bits = static_cast<uint8_t>(n + 127);
   }

+  operator bfloat16_t() {
+    uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
+    return as_type<bfloat16_t>(out);
+  }
   operator float() {
-    if (bits == 0xFF) {
-      return metal::numeric_limits<float>::quiet_NaN();
-    }
-    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
+    return static_cast<float>(this->operator bfloat16_t());
   }

   uint8_t bits;
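
The added bfloat16 path uses the fact that e8m0 is just a biased 8-bit exponent: shifting it left by 7 drops it straight into the bfloat16 exponent field, so 2^(bits - 127) comes out of one shift plus a reinterpret; the bits == 0 special case supplies the subnormal encoding of 2^-127, and dequantize_scale in the kernels below now routes through this conversion. A small host-side C++ sketch of the same idea, with float standing in for bfloat16 (so the shift is 23 bits):

// Host-side analogue of the e8m0 scale decode (float stands in for bfloat16):
// the 8-bit biased exponent is shifted into the exponent field of the wider
// type, giving 2^(bits - 127) without ldexp or branching on the common path.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

float dequant_e8m0(uint8_t bits) {
  // bits == 0 encodes 2^-127, which needs a subnormal (0.5 * 2^-126).
  uint32_t u = (bits == 0) ? 0x00400000u : (uint32_t(bits) << 23);
  float out;
  std::memcpy(&out, &u, sizeof out);
  return out;
}

int main() {
  // 0xFF (NaN in e8m0, which the removed float path returned) maps to +inf here.
  for (int b = 0; b < 255; ++b) {
    assert(dequant_e8m0(uint8_t(b)) == std::ldexp(1.0f, b - 127));
  }
  return 0;
}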

Changed file 3 of 3:

@@ -29,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
 template <typename T>
 static inline T dequantize_scale(uint8_t s) {
-  using FOrI = union {
-    bfloat16_t f;
-    uint16_t i;
-  };
-  FOrI out;
-  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
-  return static_cast<T>(out.f);
+  return T(*(thread fp8_e8m0*)(&s));
 }

+template <int bits>
+struct Quantize {
+  uint8_t operator()(float x) {
+    if (bits == 8) {
+      return fp8_e4m3(x).bits;
+    } else {
+      return fp4_e2m1(x).bits;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  float operator()(uint8_t x) {
+    if (bits == 8) {
+      return float(*(thread fp8_e4m3*)(&x));
+    } else {
+      return float(*(thread fp4_e2m1*)(&x));
+    }
+  }
+};
+
 template <typename T, typename U, int values_per_thread>
 inline void load_vector(const device T* x, thread U* x_thread) {
   for (int i = 0; i < values_per_thread; i += 4) {
@@ -62,62 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
   }
 }

-template <typename T>
-void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
-  if (simd_gid == 0 && simd_lid < 16) {
-    lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-}
-
 template <typename U, int values_per_thread>
-inline U qdot(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut) {
+inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
   U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (values_per_thread / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }

 template <typename U, int values_per_thread>
-inline U qdot_safe(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut,
-    int N) {
+inline U
+qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {
   U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (N / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }

 template <typename U, int values_per_thread>
-inline void qouter(
-    const thread uint8_t* w,
-    U x,
-    U scale,
-    thread U* result,
-    const threadgroup U* lut) {
+inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
   for (int i = 0; i < (values_per_thread / 2); i++) {
-    result[2 * i] += x * scale * lut[w[i] & 0xf];
-    result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
+    result[2 * i] += x * scale * Dequantize<4>{}(w[i]);
+    result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);
   }
 }
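
With the LUT argument gone, qdot and qdot_safe read the packed weights as uint16 words holding four e2m1 codes each and hand the shifted word straight to Dequantize<4>; no explicit & 0xf is needed because the argument narrows to uint8_t and fp4_e2m1 only inspects its low four bits. A toy host-side C++ sketch of one such word, assuming the low-nibble-first packing and the e2m1 table above:

// Toy sketch of the per-word unpacking in qdot: one uint16_t carries four
// 4-bit e2m1 codes, lowest nibble first, each dequantized and multiplied by
// the matching activation before the group scale is applied once at the end.
#include <cstdint>
#include <cstdio>

static const float kFp4Lut[16] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                                  -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

float qdot4(uint16_t w, const float x[4], float scale) {
  float accum = x[0] * kFp4Lut[w & 0xf] +
                x[1] * kFp4Lut[(w >> 4) & 0xf] +
                x[2] * kFp4Lut[(w >> 8) & 0xf] +
                x[3] * kFp4Lut[(w >> 12) & 0xf];
  return scale * accum;
}

int main() {
  // Codes 2, 3, 4, 5 decode to 1, 1.5, 2, 3; with unit activations the dot is 7.5.
  const float x[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  std::printf("%g\n", qdot4(0x5432, x, 1.0f));
}
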
@@ -192,7 +187,10 @@ struct QuantizedBlockLoader {
             bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         lut(lut_) {
-    load_fp4_lut(lut, simd_group_id, simd_lane_id);
+    if (simd_group_id == 0 && simd_lane_id < 16) {
+      lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
   }

   void load_unsafe() const {
@@ -264,10 +262,7 @@ METAL_FUNC void fp_qmv_quad_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
   constexpr int pack_factor = 8;
   constexpr int values_per_thread = D / QUAD_SIZE;
@@ -279,7 +274,6 @@ METAL_FUNC void fp_qmv_quad_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);

   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -299,7 +293,7 @@ METAL_FUNC void fp_qmv_quad_impl(
     U s = dequantize_scale<U>(sl[0]);
     if (row * quads_per_simd + out_row < out_vec_size) {
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
@@ -321,8 +315,7 @@ METAL_FUNC void fp_qmv_fast_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int packs_per_thread = 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
@@ -337,7 +330,6 @@ METAL_FUNC void fp_qmv_fast_impl(
   typedef float U;
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);

   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -358,7 +350,7 @@ METAL_FUNC void fp_qmv_fast_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }

     ws += block_size * bytes_per_pack / pack_factor;
@@ -384,8 +376,7 @@ METAL_FUNC void fp_qmv_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
@@ -402,7 +393,6 @@ METAL_FUNC void fp_qmv_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);

   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -433,7 +423,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       uint8_t s = sl[0];
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
     ws += block_size * bytes_per_pack / pack_factor;
@@ -452,7 +442,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
@@ -481,7 +471,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
     ws += block_size * bytes_per_pack / pack_factor;
@@ -501,7 +491,7 @@ METAL_FUNC void fp_qmv_impl(
       U s = dequantize_scale<U>(sl[0]);
       result[row] +=
-          qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
+          qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
     }
   }
   for (int row = 0; row < results_per_simdgroup; row++) {
@@ -523,8 +513,7 @@ METAL_FUNC void fp_qvm_impl(
     const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int pack_factor = get_pack_factor<32>();
   constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -545,8 +534,6 @@ METAL_FUNC void fp_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;

-  load_fp4_lut(lut, simd_gid, simd_lid);
-
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
@@ -568,7 +555,7 @@ METAL_FUNC void fp_qvm_impl(
     scale = dequantize_scale<U>(*scales);
     w_local = *((device vec_w*)ws);
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);

     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -581,7 +568,7 @@ METAL_FUNC void fp_qvm_impl(
     w_local = *((device vec_w*)ws);
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);

     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -596,7 +583,7 @@ METAL_FUNC void fp_qvm_impl(
       scale = 0;
     }
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
   }

   // Accumulate in the simdgroup
@@ -975,9 +962,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
     const constant int64_t* s_strides,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   if (batched) {
     int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets(
@@ -995,20 +980,8 @@ template <typename T, int group_size, int bits, int D, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_quad_impl<T, group_size, bits, D>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size,
-      out_vec_size,
-      tid,
-      quad_gid,
-      quad_lid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
 }

 template <typename T, int group_size, int bits, bool batched>
@@ -1046,9 +1019,8 @@ template <typename T, int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <typename T, const int group_size, int bits, bool batched>
@@ -1086,9 +1058,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <typename T, const int group_size, int bits, bool batched>
@@ -1126,9 +1097,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <typename T, const int group_size, int bits, int split_k = 32>
@@ -1170,18 +1140,8 @@ template <typename T, const int group_size, int bits, int split_k = 32>
   int in_vec_size_adj =
       tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size_adj,
-      out_vec_size,
-      tid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <
@@ -1342,9 +1302,8 @@ template <typename T, int group_size, int bits>
       w_strides,
       s_strides,
       tid);
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <typename T, int group_size, int bits>
@@ -1392,9 +1351,8 @@ template <typename T, int group_size, int bits>
       w_strides,
       s_strides,
       tid);
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <typename T, int group_size, int bits>
@@ -1442,9 +1400,8 @@ template <typename T, int group_size, int bits>
       w_strides,
       s_strides,
       tid);
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }

 template <
@@ -1771,28 +1728,6 @@ template <
   }
 }

-template <int bits>
-struct Quantize {
-  uint8_t operator()(float x) {
-    if (bits == 8) {
-      return fp8_e4m3(x).bits;
-    } else {
-      return fp4_e2m1(x).bits;
-    }
-  }
-};
-
-template <int bits>
-struct Dequantize {
-  float operator()(uint8_t x) {
-    if (bits == 8) {
-      return float(*(thread fp8_e4m3*)(&x));
-    } else {
-      return float(*(thread fp4_e2m1*)(&x));
-    }
-  }
-};
-
 template <typename T, const int group_size, const int bits>
 [[kernel]] void fp_quantize(
     const device T* w [[buffer(0)]],