mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 13:51:13 +08:00
Faster bfloat quantized mat-vec and vec-mat (#663)
This commit is contained in:
parent
d12573daa6
commit
3756381358
@ -15,6 +15,14 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T> struct AccT {
|
||||
typedef T acc_t;
|
||||
};
|
||||
|
||||
template <> struct AccT<bfloat16_t> {
|
||||
typedef float acc_t;
|
||||
};
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
threadgroup T biases_block[BM * groups_per_block];
|
||||
threadgroup T x_block[colgroup];
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[colgroup];
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread T result = 0;
|
||||
thread T scale = 1;
|
||||
thread T bias = 0;
|
||||
thread T x_thread[el_per_thread];
|
||||
thread U result = 0;
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_thread[el_per_thread];
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / el_per_thread;
|
||||
@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_thread; k++) {
|
||||
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
|
||||
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
y[out_row] = result;
|
||||
y[out_row] = static_cast<T>(result);
|
||||
}
|
||||
}
|
||||
|
||||
@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
threadgroup T biases_block[BM * groups_per_block];
|
||||
threadgroup T x_block[BM];
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[BM];
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread T result[el_per_int] = {0};
|
||||
thread T scale = 1;
|
||||
thread T bias = 0;
|
||||
thread T x_local = 0;
|
||||
thread U result[el_per_int] = {0};
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
|
||||
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
y[k] = result[k];
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user