Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
use faster dequant for fp4 qmv (#2720)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
@@ -49,7 +49,10 @@ struct fp4_e2m1 {
   }
 
   operator float() {
-    return FP4_LUT[bits];
+    half converted = as_type<half>(ushort((bits & 7) << 9));
+    converted *= 16384.0;
+    converted = bits & 8 ? -converted : converted;
+    return converted;
   }
 
   uint8_t bits;
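The new operator float() decodes E2M1 with bit arithmetic instead of indexing FP4_LUT: the three magnitude bits are shifted into the exponent/mantissa field of a half, the result is multiplied by 16384 (2^14) to cancel the half exponent bias, and bit 3 supplies the sign. Below is a host-side sketch of the same trick (not MLX code); it uses float bits (shift by 22, rescale by 2^126) where the kernel uses half, and assumes the standard E2M1 value table {0, 0.5, 1, 1.5, 2, 3, 4, 6} as the reference.

// Host-side sketch of the E2M1 bit decode used above (assumed reference table,
// not MLX code). float stands in for Metal's half: shift 22 instead of 9,
// rescale by 2^126 instead of 2^14.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float fp4_e2m1_decode(uint8_t bits) {
  uint32_t u = uint32_t(bits & 7) << 22;       // 2 exponent bits + 1 mantissa bit
  float converted;
  std::memcpy(&converted, &u, sizeof(converted));
  converted = std::ldexp(converted, 126);      // cancel the float exponent bias
  return (bits & 8) ? -converted : converted;  // sign lives in bit 3
}

int main() {
  const float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  for (int code = 0; code < 16; ++code) {
    float expected = (code & 8) ? -lut[code & 7] : lut[code & 7];
    std::printf("%2d: %g (expected %g)\n", code, fp4_e2m1_decode(uint8_t(code)), expected);
  }
  return 0;
}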
@@ -1,12 +1,5 @@
 #pragma once
 
-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline float fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
-
 struct fp8_e4m3 {
   template <typename T>
   fp8_e4m3(T f) {
@@ -14,7 +7,7 @@ struct fp8_e4m3 {
     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
     uint32_t fp8_max = 543 << 21;
     uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
     uint32_t sign = f_bits & 0x80000000;
     f_bits ^= sign;
     if (f_bits >= fp8_max) {
@@ -22,8 +15,8 @@ struct fp8_e4m3 {
       bits = 0x7E;
     } else {
       if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        f_bits = as_type<uint32_t>(
+            as_type<float>(f_bits) + as_type<float>(denorm_mask));
         bits = static_cast<uint8_t>(f_bits - denorm_mask);
       } else {
         // resulting mantissa is odd
@@ -53,7 +46,7 @@ struct fp8_e4m3 {
         ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
           inf_nan_mask) &
          ~zero_mask);
-    return fp32_from_bits(result);
+    return as_type<float>(result);
   }
 
   uint8_t bits;
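The fp32_from_bits / fp32_to_bits helpers above are deleted in favor of Metal's built-in as_type<> bit reinterpretation, used throughout the rewritten fp8_e4m3 conversion. For readers more used to standard C++, the deleted helpers correspond to C++20's std::bit_cast; a tiny sketch of that equivalence (not MLX code):

// Sketch (not MLX code): the removed helpers expressed with std::bit_cast,
// which is the host-side analogue of Metal's as_type<>. Requires C++20.
#include <bit>
#include <cassert>
#include <cstdint>

inline float fp32_from_bits(uint32_t bits) { return std::bit_cast<float>(bits); }
inline uint32_t fp32_to_bits(float x) { return std::bit_cast<uint32_t>(x); }

int main() {
  assert(fp32_to_bits(1.0f) == 0x3F800000u);
  assert(fp32_from_bits(0x40000000u) == 2.0f);
  return 0;
}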
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
     bits = static_cast<uint8_t>(n + 127);
   }
 
+  operator bfloat16_t() {
+    uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
+    return as_type<bfloat16_t>(out);
+  }
   operator float() {
-    if (bits == 0xFF) {
-      return metal::numeric_limits<float>::quiet_NaN();
-    }
-    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
+    return static_cast<float>(this->operator bfloat16_t());
   }
 
   uint8_t bits;
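The added operator bfloat16_t() decodes the E8M0 scale by shifting the stored byte straight into the bfloat16 exponent field, which gives 2^(bits - 127) without calling metal::ldexp; bits == 0 is special-cased to 0x40, the bfloat16 subnormal encoding 2^-127, since a plain shift would produce zero. Note the removed 0xFF -> NaN branch no longer applies: 0xFF shifts to 0x7F80, which is infinity in bfloat16. A host-side sketch of the same decode (not MLX code), widening the bfloat16 bit pattern to float:

// Sketch (not MLX code): decode an E8M0 scale byte by building bfloat16 bits,
// then widen bfloat16 -> float by appending 16 zero mantissa bits.
#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

static float e8m0_to_float(uint8_t bits) {
  // bits == 0 would shift to 0x0000 (i.e. 0.0f), so use the bfloat16 subnormal
  // 0x0040, which encodes 2^-127.
  uint16_t bf16 = (bits == 0) ? uint16_t(0x40) : uint16_t(uint16_t(bits) << 7);
  return std::bit_cast<float>(uint32_t(bf16) << 16);
}

int main() {
  const int test_bits[] = {0, 1, 126, 127, 128, 254};
  for (int b : test_bits) {
    std::printf("%3d -> %g (ldexp reference: %g)\n",
                b, e8m0_to_float(uint8_t(b)), std::ldexp(1.0f, b - 127));
  }
  return 0;
}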
@@ -29,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
 
 template <typename T>
 static inline T dequantize_scale(uint8_t s) {
-  using FOrI = union {
-    bfloat16_t f;
-    uint16_t i;
-  };
-  FOrI out;
-  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
-  return static_cast<T>(out.f);
+  return T(*(thread fp8_e8m0*)(&s));
 }
 
+template <int bits>
+struct Quantize {
+  uint8_t operator()(float x) {
+    if (bits == 8) {
+      return fp8_e4m3(x).bits;
+    } else {
+      return fp4_e2m1(x).bits;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  float operator()(uint8_t x) {
+    if (bits == 8) {
+      return float(*(thread fp8_e4m3*)(&x));
+    } else {
+      return float(*(thread fp4_e2m1*)(&x));
+    }
+  }
+};
+
 template <typename T, typename U, int values_per_thread>
 inline void load_vector(const device T* x, thread U* x_thread) {
   for (int i = 0; i < values_per_thread; i += 4) {
@@ -62,62 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
   }
 }
 
-template <typename T>
-void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
-  if (simd_gid == 0 && simd_lid < 16) {
-    lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-}
-
 template <typename U, int values_per_thread>
-inline U qdot(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut) {
+inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
   U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (values_per_thread / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }
 
 template <typename U, int values_per_thread>
-inline U qdot_safe(
-    const device uint8_t* w,
-    const thread U* x_thread,
-    U scale,
-    const threadgroup U* lut,
-    int N) {
+inline U
+qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {
   U accum = 0;
 
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (N / 4); i++) {
     accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0xf] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
+        (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
+         x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
+         x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
+         x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
   }
   return scale * accum;
 }
 
 template <typename U, int values_per_thread>
-inline void qouter(
-    const thread uint8_t* w,
-    U x,
-    U scale,
-    thread U* result,
-    const threadgroup U* lut) {
+inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
   for (int i = 0; i < (values_per_thread / 2); i++) {
-    result[2 * i] += x * scale * lut[w[i] & 0xf];
-    result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
+    result[2 * i] += x * scale * Dequantize<4>{}(w[i]);
+    result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);
   }
 }
 
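With Dequantize<4> now backed by the bit-decoding fp4_e2m1 conversion, qdot, qdot_safe, and qouter no longer take a threadgroup LUT, and load_fp4_lut with its barrier disappears from the fp4 matrix-vector paths below. For reference, a scalar sketch (not MLX code) of what one qdot iteration computes for a uint16 that packs four 4-bit codes, here using the assumed standard E2M1 value table in place of Dequantize<4>:

// Sketch (not MLX code): scalar reference for one qdot step over a uint16_t
// packing four E2M1 codes; kE2M1 is the assumed standard value table.
#include <cstdint>
#include <cstdio>

static const float kE2M1[16] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                                -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

static float qdot_step(uint16_t w, const float x[4], float scale) {
  float accum = x[0] * kE2M1[w & 0xF] +
                x[1] * kE2M1[(w >> 4) & 0xF] +
                x[2] * kE2M1[(w >> 8) & 0xF] +
                x[3] * kE2M1[(w >> 12) & 0xF];
  return scale * accum;  // the per-group scale is applied once, outside the sum
}

int main() {
  const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // codes 2, 1, 0, 7 decode to 1.0, 0.5, 0.0, 6.0
  uint16_t w = uint16_t(2 | (1 << 4) | (0 << 8) | (7 << 12));
  std::printf("%g\n", qdot_step(w, x, 0.5f));  // 0.5 * (1 + 1 + 0 + 24) = 13
  return 0;
}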
@@ -192,7 +187,10 @@ struct QuantizedBlockLoader {
             bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         lut(lut_) {
-    load_fp4_lut(lut, simd_group_id, simd_lane_id);
+    if (simd_group_id == 0 && simd_lane_id < 16) {
+      lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
   void load_unsafe() const {
@@ -264,10 +262,7 @@ METAL_FUNC void fp_qmv_quad_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
   constexpr int pack_factor = 8;
   constexpr int values_per_thread = D / QUAD_SIZE;
@@ -279,7 +274,6 @@ METAL_FUNC void fp_qmv_quad_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -299,7 +293,7 @@ METAL_FUNC void fp_qmv_quad_impl(
 
     U s = dequantize_scale<U>(sl[0]);
     if (row * quads_per_simd + out_row < out_vec_size) {
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
 
@@ -321,8 +315,7 @@ METAL_FUNC void fp_qmv_fast_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int packs_per_thread = 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
@@ -337,7 +330,6 @@ METAL_FUNC void fp_qmv_fast_impl(
   typedef float U;
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -358,7 +350,7 @@ METAL_FUNC void fp_qmv_fast_impl(
       const device auto* sl = scales + row * in_vec_size_g;
 
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -384,8 +376,7 @@ METAL_FUNC void fp_qmv_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
@@ -402,7 +393,6 @@ METAL_FUNC void fp_qmv_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -433,7 +423,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
 
       uint8_t s = sl[0];
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -452,7 +442,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
 
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
   }
 
@@ -481,7 +471,7 @@ METAL_FUNC void fp_qmv_impl(
       const device auto* sl = scales + row * in_vec_size_g;
 
       U s = dequantize_scale<U>(sl[0]);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -501,7 +491,7 @@ METAL_FUNC void fp_qmv_impl(
 
       U s = dequantize_scale<U>(sl[0]);
       result[row] +=
-          qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
+          qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
     }
   }
   for (int row = 0; row < results_per_simdgroup; row++) {
@@ -523,8 +513,7 @@ METAL_FUNC void fp_qvm_impl(
     const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]],
-    threadgroup float* lut) {
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int num_simdgroups = 2;
   constexpr int pack_factor = get_pack_factor<32>();
   constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -545,8 +534,6 @@ METAL_FUNC void fp_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;
 
-  load_fp4_lut(lut, simd_gid, simd_lid);
-
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
@@ -568,7 +555,7 @@ METAL_FUNC void fp_qvm_impl(
     scale = dequantize_scale<U>(*scales);
     w_local = *((device vec_w*)ws);
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
 
     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -581,7 +568,7 @@ METAL_FUNC void fp_qvm_impl(
     w_local = *((device vec_w*)ws);
 
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
 
     x += block_size;
     scales += block_size * out_vec_size_g;
@@ -596,7 +583,7 @@ METAL_FUNC void fp_qvm_impl(
       scale = 0;
     }
     qouter<U, tn * pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result, lut);
+        (thread uint8_t*)&w_local, x_local, scale, result);
   }
 
   // Accumulate in the simdgroup
@@ -975,9 +962,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
     const constant int64_t* s_strides,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint quad_lid [[thread_index_in_quadgroup]]) {
   if (batched) {
     int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets(
@@ -995,20 +980,8 @@ template <typename T, int group_size, int bits, int D, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_quad_impl<T, group_size, bits, D>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size,
-      out_vec_size,
-      tid,
-      quad_gid,
-      quad_lid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
 }
 
 template <typename T, int group_size, int bits, bool batched>
@@ -1046,9 +1019,8 @@ template <typename T, int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, bool batched>
@@ -1086,9 +1058,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, bool batched>
@@ -1126,9 +1097,8 @@ template <typename T, const int group_size, int bits, bool batched>
         s_strides,
         tid);
   }
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, const int group_size, int bits, int split_k = 32>
@@ -1170,18 +1140,8 @@ template <typename T, const int group_size, int bits, int split_k = 32>
   int in_vec_size_adj =
       tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
 
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w,
-      scales,
-      x,
-      y,
-      in_vec_size_adj,
-      out_vec_size,
-      tid,
-      simd_gid,
-      simd_lid,
-      lut);
+      w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <
@@ -1342,9 +1302,8 @@ template <typename T, int group_size, int bits>
       w_strides,
       s_strides,
       tid);
-  threadgroup float lut[16];
   fp_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, int group_size, int bits>
@@ -1392,9 +1351,8 @@ template <typename T, int group_size, int bits>
      w_strides,
      s_strides,
      tid);
-  threadgroup float lut[16];
   fp_qmv_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <typename T, int group_size, int bits>
@@ -1442,9 +1400,8 @@ template <typename T, int group_size, int bits>
      w_strides,
      s_strides,
      tid);
-  threadgroup float lut[16];
   fp_qvm_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
 
 template <
@@ -1771,28 +1728,6 @@ template <
   }
 }
 
-template <int bits>
-struct Quantize {
-  uint8_t operator()(float x) {
-    if (bits == 8) {
-      return fp8_e4m3(x).bits;
-    } else {
-      return fp4_e2m1(x).bits;
-    }
-  }
-};
-
-template <int bits>
-struct Dequantize {
-  float operator()(uint8_t x) {
-    if (bits == 8) {
-      return float(*(thread fp8_e4m3*)(&x));
-    } else {
-      return float(*(thread fp4_e2m1*)(&x));
-    }
-  }
-};
-
 template <typename T, const int group_size, const int bits>
 [[kernel]] void fp_quantize(
     const device T* w [[buffer(0)]],