Faster bfloat quantized mat-vec and vec-mat (#663)

Awni Hannun 2024-02-11 21:53:16 -08:00 committed by GitHub
parent d12573daa6
commit 3756381358

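In short: the `qmv` (quantized matrix-vector) and `qvm` (quantized vector-matrix) kernels previously staged scales, biases, and the input in the io type T and accumulated in T. For T = bfloat16_t, which carries only 8 bits of significand, accumulating in T loses precision, and per the commit title it was also slower. The diff adds an `AccT` trait that selects an accumulation type U: float when T is bfloat16_t, otherwise T itself. Below is a minimal standalone C++ sketch of the trait pattern; the `bf16` struct is a hypothetical stand-in for Metal's bfloat16_t, not part of the commit.

#include <cstdint>
#include <type_traits>

struct bf16 { uint16_t bits; };  // hypothetical stand-in for Metal's bfloat16_t

// Default: accumulate in the same type as the data.
template <typename T> struct AccT {
  typedef T acc_t;
};

// bfloat16 keeps only 8 significand bits, so widen accumulation to float.
template <> struct AccT<bf16> {
  typedef float acc_t;
};

// The trait resolves at compile time, mirroring the kernel's
// `typedef typename AccT<T>::acc_t U;` line.
static_assert(std::is_same<AccT<bf16>::acc_t, float>::value, "bf16 accumulates in float");
static_assert(std::is_same<AccT<float>::acc_t, float>::value, "other types are unchanged");

int main() { return 0; }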

@@ -15,6 +15,14 @@ using namespace metal;
 
 MLX_MTL_CONST int SIMD_SIZE = 32;
 
+template <typename T> struct AccT {
+  typedef T acc_t;
+};
+
+template <> struct AccT<bfloat16_t> {
+  typedef float acc_t;
+};
+
 template <typename T, const int BM, const int BN, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
@@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   constexpr int groups_per_block = colgroup / group_size;
   constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
 
-  threadgroup T scales_block[BM * groups_per_block];
-  threadgroup T biases_block[BM * groups_per_block];
-  threadgroup T x_block[colgroup];
+  typedef typename AccT<T>::acc_t U;
+  threadgroup U scales_block[BM * groups_per_block];
+  threadgroup U biases_block[BM * groups_per_block];
+  threadgroup U x_block[colgroup];
 
   thread uint32_t w_local;
-  thread T result = 0;
-  thread T scale = 1;
-  thread T bias = 0;
-  thread T x_thread[el_per_thread];
+  thread U result = 0;
+  thread U scale = 1;
+  thread U bias = 0;
+  thread U x_thread[el_per_thread];
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
     // Do all the work.
     #pragma clang loop unroll(full)
     for (int k=0; k<el_per_thread; k++) {
-      result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
+      result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
       w_local >>= bits;
     }
   }
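The loop above is the heart of `qmv`: each thread unpacks `el_per_thread` quantized values from `w_local`, dequantizes them with `scale` and `bias`, and multiply-accumulates against its slice of x, now entirely in U. A scalar C++ model of one thread's work on a single packed word follows; it assumes 4-bit quantization, and the function name and fixed parameters are illustrative, not the kernel's.

#include <cstdint>

// Scalar sketch of the unrolled dequantize-and-accumulate loop,
// assuming bits = 4 (eight elements packed per 32-bit word).
float dequant_dot(uint32_t w_local, const float* x_thread,
                  float scale, float bias) {
  constexpr int bits = 4;
  constexpr int el_per_thread = 32 / bits;        // 8 packed elements
  constexpr uint32_t bitmask = (1u << bits) - 1;  // low `bits` bits

  float result = 0.0f;  // float plays the role of U for bfloat16 inputs
  for (int k = 0; k < el_per_thread; k++) {
    // Dequantize one element, then multiply-accumulate against x.
    result += (scale * static_cast<float>(w_local & bitmask) + bias) * x_thread[k];
    w_local >>= bits;  // shift the next packed element into place
  }
  return result;
}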
@@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
 
   // Store the result
   if (simd_lid == 0) {
-    y[out_row] = result;
+    y[out_row] = static_cast<T>(result);
   }
 }
@@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   constexpr int colgroup = BN * el_per_int;
   constexpr int groups_per_block = colgroup / group_size;
 
-  threadgroup T scales_block[BM * groups_per_block];
-  threadgroup T biases_block[BM * groups_per_block];
-  threadgroup T x_block[BM];
+  typedef typename AccT<T>::acc_t U;
+  threadgroup U scales_block[BM * groups_per_block];
+  threadgroup U biases_block[BM * groups_per_block];
+  threadgroup U x_block[BM];
 
   thread uint32_t w_local;
-  thread T result[el_per_int] = {0};
-  thread T scale = 1;
-  thread T bias = 0;
-  thread T x_local = 0;
+  thread U result[el_per_int] = {0};
+  thread U scale = 1;
+  thread U bias = 0;
+  thread U x_local = 0;
 
   // Adjust positions
   const int out_vec_size_w = out_vec_size / el_per_int;
@@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
     // Do all the work.
    #pragma clang loop unroll(full)
     for (int k=0; k<el_per_int; k++) {
-      result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
+      result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
       w_local >>= bits;
     }
   }
@@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   if (simd_lid == 0) {
     #pragma clang loop unroll(full)
     for (int k=0; k<el_per_int; k++) {
-      y[k] = result[k];
+      y[k] = static_cast<T>(result[k]);
     }
   }
 }
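Why widen the accumulator at all? bfloat16 keeps float32's exponent range but only 8 significand bits (7 stored), so once a running sum grows, small per-element contributions round away entirely. The self-contained host-side demonstration below simulates bfloat16 arithmetic with a round-to-nearest-even truncation; the helper and constants are illustrative, not from the commit.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bfloat16 value (round-to-nearest-even)
// and return it as a float: a host-side stand-in for bfloat16_t.
float to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  u += 0x7FFF + ((u >> 16) & 1);  // round to nearest, ties to even
  u &= 0xFFFF0000u;               // truncate to bf16's 7 explicit mantissa bits
  float out;
  std::memcpy(&out, &u, sizeof out);
  return out;
}

int main() {
  // Sum 4096 copies of 0.01 in float and in simulated bfloat16.
  float acc_f = 0.0f, acc_bf = 0.0f;
  for (int i = 0; i < 4096; i++) {
    acc_f += 0.01f;
    acc_bf = to_bf16(acc_bf + to_bf16(0.01f));
  }
  std::printf("float accumulator: %g\n", acc_f);   // ~40.96
  std::printf("bf16  accumulator: %g\n", acc_bf);  // stalls near 4
  return 0;
}

With a float accumulator the sum lands near the expected 40.96; the simulated bfloat16 accumulator climbs to 4.0 and then stops, because each 0.01 falls below half an ulp of the running sum. Casting back to T only at the final store, as both kernels now do, keeps full precision through the whole reduction.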