Faster bfloat quantized mat-vec and vec-mat (#663)

This commit is contained in:
Awni Hannun 2024-02-11 21:53:16 -08:00 committed by GitHub
parent d12573daa6
commit 3756381358
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,14 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int SIMD_SIZE = 32;
// Trait that maps an element type T to the type exposed as `acc_t`.
// In the generic case `acc_t` is simply T itself; specializations may
// substitute a different type for particular element types.
template <typename T>
struct AccT {
  using acc_t = T;
};
// Specialization for bfloat16_t: `acc_t` is float rather than
// bfloat16_t, so code that works in `AccT<T>::acc_t` uses float
// whenever T is bfloat16_t.
template <>
struct AccT<bfloat16_t> {
  using acc_t = float;
};
template <typename T, const int BM, const int BN, const int group_size, const int bits> template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv( [[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
constexpr int groups_per_block = colgroup / group_size; constexpr int groups_per_block = colgroup / group_size;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE; constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
threadgroup T scales_block[BM * groups_per_block]; typedef typename AccT<T>::acc_t U;
threadgroup T biases_block[BM * groups_per_block]; threadgroup U scales_block[BM * groups_per_block];
threadgroup T x_block[colgroup]; threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[colgroup];
thread uint32_t w_local; thread uint32_t w_local;
thread T result = 0; thread U result = 0;
thread T scale = 1; thread U scale = 1;
thread T bias = 0; thread U bias = 0;
thread T x_thread[el_per_thread]; thread U x_thread[el_per_thread];
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread; const int in_vec_size_w = in_vec_size / el_per_thread;
@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work. // Do all the work.
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) { for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k]; result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= bits; w_local >>= bits;
} }
} }
@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Store the result // Store the result
if (simd_lid == 0) { if (simd_lid == 0) {
y[out_row] = result; y[out_row] = static_cast<T>(result);
} }
} }
@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
constexpr int colgroup = BN * el_per_int; constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size; constexpr int groups_per_block = colgroup / group_size;
threadgroup T scales_block[BM * groups_per_block]; typedef typename AccT<T>::acc_t U;
threadgroup T biases_block[BM * groups_per_block]; threadgroup U scales_block[BM * groups_per_block];
threadgroup T x_block[BM]; threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[BM];
thread uint32_t w_local; thread uint32_t w_local;
thread T result[el_per_int] = {0}; thread U result[el_per_int] = {0};
thread T scale = 1; thread U scale = 1;
thread T bias = 0; thread U bias = 0;
thread T x_local = 0; thread U x_local = 0;
// Adjust positions // Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int; const int out_vec_size_w = out_vec_size / el_per_int;
@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work. // Do all the work.
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) { for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local; result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
w_local >>= bits; w_local >>= bits;
} }
} }
@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
if (simd_lid == 0) { if (simd_lid == 0) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) { for (int k=0; k<el_per_int; k++) {
y[k] = result[k]; y[k] = static_cast<T>(result[k]);
} }
} }
} }