mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
speedup
This commit is contained in:
parent
6295e53216
commit
51449428dd
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
@ -24,6 +24,17 @@ inline constexpr short get_bytes_per_pack() {
|
|||||||
return wsize / 8;
|
return wsize / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts an 8-bit block scale `s` into a floating-point value of type T.
//
// The byte is shifted left by 7 so that it sits directly above the low 7 bits
// of a 16-bit pattern; that pattern is then read back as a bfloat16_t through
// the union and cast to T. In this change the helper replaces calls of the
// form metal::pow(2, s - 127), so the intended result is presumably
// 2^(s - 127) -- this relies on bfloat16_t storing an 8-bit exponent above a
// 7-bit mantissa (NOTE(review): confirm against the bfloat16_t definition in
// use).
//
// Template parameter T: the type the scale is returned as.
// Parameter s: the raw scale byte.
// Returns: the reinterpreted value converted with static_cast<T>.
template <typename T>
|
||||||
|
static inline T dequantize_scale(uint8_t s) {
|
||||||
|
// Two views of the same 16 bits: written as an integer, read as bfloat16_t.
using FOrI = union {
|
||||||
|
bfloat16_t f;
|
||||||
|
uint16_t i;
|
||||||
|
};
|
||||||
|
FOrI out;
|
||||||
|
// s == 0 is special-cased: shifting would give the all-zero pattern, so the
// fixed pattern 0x40 is used instead (presumably the subnormal 2^-127 --
// TODO confirm). NOTE(review): s == 255 sets every bit above the low 7;
// whether that value is expected from callers is not visible here.
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
|
||||||
|
return static_cast<T>(out.f);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int values_per_thread>
|
template <typename T, typename U, int values_per_thread>
|
||||||
inline void load_vector(const device T* x, thread U* x_thread) {
|
inline void load_vector(const device T* x, thread U* x_thread) {
|
||||||
for (int i = 0; i < values_per_thread; i += 4) {
|
for (int i = 0; i < values_per_thread; i += 4) {
|
||||||
@ -48,7 +59,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
constant float MXFP4_LUT[16] = {
|
constexpr constant static float MXFP4_LUT[16] = {
|
||||||
+0.0f,
|
+0.0f,
|
||||||
+0.5f,
|
+0.5f,
|
||||||
+1.0f,
|
+1.0f,
|
||||||
@ -66,51 +77,74 @@ constant float MXFP4_LUT[16] = {
|
|||||||
-4.0f,
|
-4.0f,
|
||||||
-6.0f};
|
-6.0f};
|
||||||
|
|
||||||
template <typename U, int values_per_thread>
|
template <typename T>
|
||||||
inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
|
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||||
U accum = 0;
|
if (simd_gid == 0 && simd_lid < 16) {
|
||||||
|
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, int values_per_thread>
|
||||||
|
inline U qdot(
|
||||||
|
const device uint8_t* w,
|
||||||
|
const thread U* x_thread,
|
||||||
|
U scale,
|
||||||
|
const threadgroup U* lut) {
|
||||||
|
U accum = 0;
|
||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||||
accum +=
|
accum +=
|
||||||
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
|
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
||||||
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
|
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
|
||||||
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
|
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
|
||||||
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
|
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
|
||||||
}
|
}
|
||||||
return scale * accum;
|
return scale * accum;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, typename S>
|
template <typename U, int values_per_thread>
|
||||||
inline U
|
inline U qdot_safe(
|
||||||
qdot_safe(const device uint8_t* w, const thread U* x_thread, S scale, int N) {
|
const device uint8_t* w,
|
||||||
|
const thread U* x_thread,
|
||||||
|
U scale,
|
||||||
|
const threadgroup U* lut,
|
||||||
|
int N) {
|
||||||
U accum = 0;
|
U accum = 0;
|
||||||
|
|
||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
accum +=
|
accum +=
|
||||||
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
|
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
||||||
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
|
x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
|
||||||
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
|
x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
|
||||||
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
|
x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
|
||||||
}
|
}
|
||||||
return scale * accum;
|
return scale * accum;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread>
|
template <typename U, int values_per_thread>
|
||||||
inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
|
inline void qouter(
|
||||||
|
const thread uint8_t* w,
|
||||||
|
U x,
|
||||||
|
U scale,
|
||||||
|
thread U* result,
|
||||||
|
const threadgroup U* lut) {
|
||||||
for (int i = 0; i < (values_per_thread / 2); i++) {
|
for (int i = 0; i < (values_per_thread / 2); i++) {
|
||||||
result[2 * i] += x * scale * MXFP4_LUT[w[i] & 0x0f];
|
result[2 * i] += x * scale * lut[w[i] & 0x0f];
|
||||||
result[2 * i + 1] += x * scale * MXFP4_LUT[(w[i] & 0xf0) >> 4];
|
result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int N>
|
template <typename U, int N>
|
||||||
inline void
|
inline void dequantize(
|
||||||
dequantize(const device uint8_t* w, U scale, threadgroup U* w_local) {
|
const device uint8_t* w,
|
||||||
|
U scale,
|
||||||
|
threadgroup U* w_local,
|
||||||
|
const threadgroup U* lut) {
|
||||||
for (int i = 0; i < (N / 2); i++) {
|
for (int i = 0; i < (N / 2); i++) {
|
||||||
w_local[2 * i] = scale * static_cast<U>(MXFP4_LUT[w[i] & 0x0f]);
|
w_local[2 * i] = scale * lut[w[i] & 0x0f];
|
||||||
w_local[2 * i + 1] = scale * static_cast<U>(MXFP4_LUT[(w[i] & 0xf0) >> 4]);
|
w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,12 +184,14 @@ struct QuantizedBlockLoader {
|
|||||||
threadgroup T* dst;
|
threadgroup T* dst;
|
||||||
const device uint8_t* src;
|
const device uint8_t* src;
|
||||||
const device S* scales;
|
const device S* scales;
|
||||||
|
threadgroup T* lut;
|
||||||
|
|
||||||
QuantizedBlockLoader(
|
QuantizedBlockLoader(
|
||||||
const device uint8_t* src_,
|
const device uint8_t* src_,
|
||||||
const device S* scales_,
|
const device S* scales_,
|
||||||
const int src_ld_,
|
const int src_ld_,
|
||||||
threadgroup T* dst_,
|
threadgroup T* dst_,
|
||||||
|
threadgroup T* lut_,
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
: src_ld(src_ld_),
|
: src_ld(src_ld_),
|
||||||
@ -170,17 +206,20 @@ struct QuantizedBlockLoader {
|
|||||||
dst(dst_ + bi * dst_ld + bj * pack_factor),
|
dst(dst_ + bi * dst_ld + bj * pack_factor),
|
||||||
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
|
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
|
||||||
bj * bytes_per_pack),
|
bj * bytes_per_pack),
|
||||||
scales(scales_ + bi * src_ld / group_size) {}
|
scales(scales_ + bi * src_ld / group_size),
|
||||||
|
lut(lut_) {
|
||||||
|
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
|
||||||
|
}
|
||||||
|
|
||||||
void load_unsafe() const {
|
void load_unsafe() const {
|
||||||
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
|
T scale = dequantize_scale<T>(*scales);
|
||||||
for (int i = 0; i < n_reads; i++) {
|
for (int i = 0; i < n_reads; i++) {
|
||||||
dequantize<T, pack_factor>(
|
dequantize<T, pack_factor>(
|
||||||
src + i * bytes_per_pack, scale, dst + i * pack_factor);
|
src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,12 +242,13 @@ struct QuantizedBlockLoader {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
|
T scale = dequantize_scale<T>(*scales);
|
||||||
for (int i = 0; i < n_reads; i++) {
|
for (int i = 0; i < n_reads; i++) {
|
||||||
dequantize<T, pack_factor>(
|
dequantize<T, pack_factor>(
|
||||||
(device uint8_t*)(src + i * bytes_per_pack),
|
(device uint8_t*)(src + i * bytes_per_pack),
|
||||||
scale,
|
scale,
|
||||||
dst + i * pack_factor);
|
dst + i * pack_factor,
|
||||||
|
lut);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,7 +280,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
|||||||
const constant int& out_vec_size,
|
const constant int& out_vec_size,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
uint quad_lid [[thread_index_in_quadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup float* lut) {
|
||||||
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
|
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
|
||||||
constexpr int pack_factor = 8;
|
constexpr int pack_factor = 8;
|
||||||
constexpr int values_per_thread = D / QUAD_SIZE;
|
constexpr int values_per_thread = D / QUAD_SIZE;
|
||||||
@ -252,6 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
|||||||
|
|
||||||
thread U x_thread[values_per_thread];
|
thread U x_thread[values_per_thread];
|
||||||
thread U result[results_per_quadgroup] = {0};
|
thread U result[results_per_quadgroup] = {0};
|
||||||
|
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||||
@ -269,9 +313,9 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
|||||||
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
||||||
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
|
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||||
|
|
||||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
U s = dequantize_scale<U>(sl[0]);
|
||||||
if (row * quads_per_simd + out_row < out_vec_size) {
|
if (row * quads_per_simd + out_row < out_vec_size) {
|
||||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,7 +337,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
|||||||
const constant int& out_vec_size,
|
const constant int& out_vec_size,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup float* lut) {
|
||||||
constexpr int packs_per_thread = 2;
|
constexpr int packs_per_thread = 2;
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int results_per_simdgroup = 4;
|
constexpr int results_per_simdgroup = 4;
|
||||||
@ -306,9 +351,9 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
|||||||
const device uint8_t* ws = (const device uint8_t*)w;
|
const device uint8_t* ws = (const device uint8_t*)w;
|
||||||
|
|
||||||
typedef float U;
|
typedef float U;
|
||||||
|
|
||||||
thread U x_thread[values_per_thread];
|
thread U x_thread[values_per_thread];
|
||||||
thread U result[results_per_simdgroup] = {0};
|
thread U result[results_per_simdgroup] = {0};
|
||||||
|
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||||
@ -328,8 +373,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
|||||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
const device auto* sl = scales + row * in_vec_size_g;
|
const device auto* sl = scales + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
U s = dequantize_scale<U>(sl[0]);
|
||||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
ws += block_size * bytes_per_pack / pack_factor;
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
@ -355,7 +400,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
const constant int& out_vec_size,
|
const constant int& out_vec_size,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup float* lut) {
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int results_per_simdgroup = 4;
|
constexpr int results_per_simdgroup = 4;
|
||||||
constexpr int packs_per_thread = 1;
|
constexpr int packs_per_thread = 1;
|
||||||
@ -372,6 +418,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
|
|
||||||
thread U x_thread[values_per_thread];
|
thread U x_thread[values_per_thread];
|
||||||
thread U result[results_per_simdgroup] = {0};
|
thread U result[results_per_simdgroup] = {0};
|
||||||
|
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||||
@ -402,7 +449,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
const device auto* sl = scales + row * in_vec_size_g;
|
const device auto* sl = scales + row * in_vec_size_g;
|
||||||
|
|
||||||
S s = sl[0];
|
S s = sl[0];
|
||||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
ws += block_size * bytes_per_pack / pack_factor;
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
@ -420,8 +467,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
const device auto* sl = scales + row * in_vec_size_g;
|
const device auto* sl = scales + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
U s = dequantize_scale<U>(sl[0]);
|
||||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -449,8 +496,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
const device auto* sl = scales + row * in_vec_size_g;
|
const device auto* sl = scales + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
U s = dequantize_scale<U>(sl[0]);
|
||||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
ws += block_size * bytes_per_pack / pack_factor;
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
@ -468,9 +515,9 @@ METAL_FUNC void mxfp4_qmv_impl(
|
|||||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
const device auto* sl = scales + row * in_vec_size_g;
|
const device auto* sl = scales + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
U s = dequantize_scale<U>(sl[0]);
|
||||||
result[row] +=
|
result[row] +=
|
||||||
qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
|
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
@ -492,7 +539,8 @@ METAL_FUNC void mxfp4_qvm_impl(
|
|||||||
const int out_vec_size,
|
const int out_vec_size,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup float* lut) {
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int pack_factor = get_pack_factor<32>();
|
constexpr int pack_factor = get_pack_factor<32>();
|
||||||
constexpr int bytes_per_pack = get_bytes_per_pack();
|
constexpr int bytes_per_pack = get_bytes_per_pack();
|
||||||
@ -513,6 +561,8 @@ METAL_FUNC void mxfp4_qvm_impl(
|
|||||||
thread U scale = 0;
|
thread U scale = 0;
|
||||||
thread U x_local = 0;
|
thread U x_local = 0;
|
||||||
|
|
||||||
|
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
||||||
const int out_vec_size_g = out_vec_size / group_size;
|
const int out_vec_size_g = out_vec_size / group_size;
|
||||||
@ -531,10 +581,10 @@ METAL_FUNC void mxfp4_qvm_impl(
|
|||||||
if (remaining == 0) {
|
if (remaining == 0) {
|
||||||
for (int i = 0; i < in_vec_size; i += block_size) {
|
for (int i = 0; i < in_vec_size; i += block_size) {
|
||||||
x_local = *x;
|
x_local = *x;
|
||||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
scale = dequantize_scale<U>(*scales);
|
||||||
w_local = *((device vec_w*)ws);
|
w_local = *((device vec_w*)ws);
|
||||||
qouter<U, tn * pack_factor>(
|
qouter<U, tn * pack_factor>(
|
||||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||||
|
|
||||||
x += block_size;
|
x += block_size;
|
||||||
scales += block_size * out_vec_size_g;
|
scales += block_size * out_vec_size_g;
|
||||||
@ -543,11 +593,11 @@ METAL_FUNC void mxfp4_qvm_impl(
|
|||||||
} else {
|
} else {
|
||||||
for (int i = block_size; i < in_vec_size; i += block_size) {
|
for (int i = block_size; i < in_vec_size; i += block_size) {
|
||||||
x_local = *x;
|
x_local = *x;
|
||||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
scale = dequantize_scale<U>(*scales);
|
||||||
w_local = *((device vec_w*)ws);
|
w_local = *((device vec_w*)ws);
|
||||||
|
|
||||||
qouter<U, tn * pack_factor>(
|
qouter<U, tn * pack_factor>(
|
||||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||||
|
|
||||||
x += block_size;
|
x += block_size;
|
||||||
scales += block_size * out_vec_size_g;
|
scales += block_size * out_vec_size_g;
|
||||||
@ -555,14 +605,14 @@ METAL_FUNC void mxfp4_qvm_impl(
|
|||||||
}
|
}
|
||||||
if (static_cast<int>(simd_lid) < remaining) {
|
if (static_cast<int>(simd_lid) < remaining) {
|
||||||
x_local = *x;
|
x_local = *x;
|
||||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
scale = dequantize_scale<U>(*scales);
|
||||||
w_local = *((device vec_w*)ws);
|
w_local = *((device vec_w*)ws);
|
||||||
} else {
|
} else {
|
||||||
x_local = 0;
|
x_local = 0;
|
||||||
scale = 0;
|
scale = 0;
|
||||||
}
|
}
|
||||||
qouter<U, tn * pack_factor>(
|
qouter<U, tn * pack_factor>(
|
||||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate in the simdgroup
|
// Accumulate in the simdgroup
|
||||||
@ -601,7 +651,8 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint lid [[thread_index_in_threadgroup]],
|
uint lid [[thread_index_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup T* lut) {
|
||||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||||
|
|
||||||
@ -646,7 +697,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
|||||||
const short num_els = min(BM, M - y_row);
|
const short num_els = min(BM, M - y_row);
|
||||||
const short num_outs = min(BN, N - y_col);
|
const short num_outs = min(BN, N - y_col);
|
||||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||||
loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);
|
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
if (num_els < BM) {
|
if (num_els < BM) {
|
||||||
@ -725,7 +776,8 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint lid [[thread_index_in_threadgroup]],
|
uint lid [[thread_index_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
threadgroup T* lut) {
|
||||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||||
|
|
||||||
@ -767,7 +819,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
|||||||
// Make the x loader and mma operation
|
// Make the x loader and mma operation
|
||||||
const short num_els = min(BM, M - y_row);
|
const short num_els = min(BM, M - y_row);
|
||||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||||
loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);
|
loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid);
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
if (num_els < BM) {
|
if (num_els < BM) {
|
||||||
@ -941,7 +993,9 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
|||||||
const constant int64_t* s_strides,
|
const constant int64_t* s_strides,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
uint quad_lid [[thread_index_in_quadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
if (batched) {
|
if (batched) {
|
||||||
int M = x_shape[x_batch_ndims];
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets(
|
adjust_matrix_offsets(
|
||||||
@ -959,8 +1013,20 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qmv_quad_impl<T, group_size, D>(
|
mxfp4_qmv_quad_impl<T, group_size, D>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
|
w,
|
||||||
|
scales,
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
in_vec_size,
|
||||||
|
out_vec_size,
|
||||||
|
tid,
|
||||||
|
quad_gid,
|
||||||
|
quad_lid,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid,
|
||||||
|
lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, bool batched, typename S>
|
template <typename T, int group_size, bool batched, typename S>
|
||||||
@ -998,8 +1064,9 @@ template <typename T, int group_size, bool batched, typename S>
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qmv_fast_impl<T, group_size>(
|
mxfp4_qmv_fast_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, const int group_size, bool batched, typename S>
|
template <typename T, const int group_size, bool batched, typename S>
|
||||||
@ -1037,8 +1104,9 @@ template <typename T, const int group_size, bool batched, typename S>
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qmv_impl<T, group_size>(
|
mxfp4_qmv_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, const int group_size, bool batched, typename S>
|
template <typename T, const int group_size, bool batched, typename S>
|
||||||
@ -1076,8 +1144,9 @@ template <typename T, const int group_size, bool batched, typename S>
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qvm_impl<T, group_size>(
|
mxfp4_qvm_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, const int group_size, int split_k = 32, typename S>
|
template <typename T, const int group_size, int split_k = 32, typename S>
|
||||||
@ -1119,8 +1188,18 @@ template <typename T, const int group_size, int split_k = 32, typename S>
|
|||||||
int in_vec_size_adj =
|
int in_vec_size_adj =
|
||||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||||
|
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qvm_impl<T, group_size>(
|
mxfp4_qvm_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
|
w,
|
||||||
|
scales,
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
in_vec_size_adj,
|
||||||
|
out_vec_size,
|
||||||
|
tid,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid,
|
||||||
|
lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -1157,6 +1236,7 @@ template <
|
|||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BN * BK_padded];
|
threadgroup T Ws[BN * BK_padded];
|
||||||
|
threadgroup T lut[16];
|
||||||
|
|
||||||
if (batched) {
|
if (batched) {
|
||||||
adjust_matrix_offsets(
|
adjust_matrix_offsets(
|
||||||
@ -1175,7 +1255,7 @@ template <
|
|||||||
tid);
|
tid);
|
||||||
}
|
}
|
||||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -1212,6 +1292,7 @@ template <
|
|||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BK * BN_padded];
|
threadgroup T Ws[BK * BN_padded];
|
||||||
|
threadgroup T lut[16];
|
||||||
|
|
||||||
if (batched) {
|
if (batched) {
|
||||||
adjust_matrix_offsets(
|
adjust_matrix_offsets(
|
||||||
@ -1231,7 +1312,7 @@ template <
|
|||||||
}
|
}
|
||||||
|
|
||||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, typename S>
|
template <typename T, int group_size, typename S>
|
||||||
@ -1279,8 +1360,9 @@ template <typename T, int group_size, typename S>
|
|||||||
w_strides,
|
w_strides,
|
||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qmv_fast_impl<T, group_size>(
|
mxfp4_qmv_fast_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, typename S>
|
template <typename T, int group_size, typename S>
|
||||||
@ -1328,8 +1410,9 @@ template <typename T, int group_size, typename S>
|
|||||||
w_strides,
|
w_strides,
|
||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qmv_impl<T, group_size>(
|
mxfp4_qmv_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, typename S>
|
template <typename T, int group_size, typename S>
|
||||||
@ -1377,8 +1460,9 @@ template <typename T, int group_size, typename S>
|
|||||||
w_strides,
|
w_strides,
|
||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
|
threadgroup float lut[16];
|
||||||
mxfp4_qvm_impl<T, group_size>(
|
mxfp4_qvm_impl<T, group_size>(
|
||||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -1420,6 +1504,7 @@ template <
|
|||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BN * BK_padded];
|
threadgroup T Ws[BN * BK_padded];
|
||||||
|
threadgroup T lut[16];
|
||||||
|
|
||||||
adjust_matrix_offsets(
|
adjust_matrix_offsets(
|
||||||
x,
|
x,
|
||||||
@ -1442,7 +1527,7 @@ template <
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -1484,6 +1569,7 @@ template <
|
|||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BK * BN_padded];
|
threadgroup T Ws[BK * BN_padded];
|
||||||
|
threadgroup T lut[16];
|
||||||
|
|
||||||
adjust_matrix_offsets(
|
adjust_matrix_offsets(
|
||||||
x,
|
x,
|
||||||
@ -1506,7 +1592,7 @@ template <
|
|||||||
s_strides,
|
s_strides,
|
||||||
tid);
|
tid);
|
||||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||||
@ -1621,6 +1707,7 @@ template <
|
|||||||
constexpr int bytes_per_pack = get_bytes_per_pack();
|
constexpr int bytes_per_pack = get_bytes_per_pack();
|
||||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||||
|
threadgroup T lut[16];
|
||||||
|
|
||||||
using mma_t = mlx::steel::BlockMMA<
|
using mma_t = mlx::steel::BlockMMA<
|
||||||
T,
|
T,
|
||||||
@ -1709,6 +1796,7 @@ template <
|
|||||||
scales + index * stride_s,
|
scales + index * stride_s,
|
||||||
transpose ? K : N,
|
transpose ? K : N,
|
||||||
Ws,
|
Ws,
|
||||||
|
lut,
|
||||||
simd_group_id,
|
simd_group_id,
|
||||||
simd_lane_id);
|
simd_lane_id);
|
||||||
|
|
||||||
|
@ -734,6 +734,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
for L, K, D, E, I, transpose, mode in parameters:
|
for L, K, D, E, I, transpose, mode in parameters:
|
||||||
if mode == "mxfp4":
|
if mode == "mxfp4":
|
||||||
group_size = 32
|
group_size = 32
|
||||||
|
else:
|
||||||
|
group_size = 64
|
||||||
K, D = (K, D) if transpose else (D, K)
|
K, D = (K, D) if transpose else (D, K)
|
||||||
ishape = (L, I)
|
ishape = (L, I)
|
||||||
xshape = (L, 1, 1, K)
|
xshape = (L, 1, 1, K)
|
||||||
|
Loading…
Reference in New Issue
Block a user