mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
Faster bfloat quantized mat-vec and vec-mat (#663)
This commit is contained in:
parent
d12573daa6
commit
3756381358
@ -15,6 +15,14 @@ using namespace metal;
|
|||||||
|
|
||||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||||
|
|
||||||
|
template <typename T> struct AccT {
|
||||||
|
typedef T acc_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct AccT<bfloat16_t> {
|
||||||
|
typedef float acc_t;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||||
[[kernel]] void qmv(
|
[[kernel]] void qmv(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
@ -37,15 +45,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
constexpr int groups_per_block = colgroup / group_size;
|
constexpr int groups_per_block = colgroup / group_size;
|
||||||
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
||||||
|
|
||||||
threadgroup T scales_block[BM * groups_per_block];
|
typedef typename AccT<T>::acc_t U;
|
||||||
threadgroup T biases_block[BM * groups_per_block];
|
threadgroup U scales_block[BM * groups_per_block];
|
||||||
threadgroup T x_block[colgroup];
|
threadgroup U biases_block[BM * groups_per_block];
|
||||||
|
threadgroup U x_block[colgroup];
|
||||||
|
|
||||||
thread uint32_t w_local;
|
thread uint32_t w_local;
|
||||||
thread T result = 0;
|
thread U result = 0;
|
||||||
thread T scale = 1;
|
thread U scale = 1;
|
||||||
thread T bias = 0;
|
thread U bias = 0;
|
||||||
thread T x_thread[el_per_thread];
|
thread U x_thread[el_per_thread];
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size / el_per_thread;
|
const int in_vec_size_w = in_vec_size / el_per_thread;
|
||||||
@ -90,7 +99,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
// Do all the work.
|
// Do all the work.
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int k=0; k<el_per_thread; k++) {
|
for (int k=0; k<el_per_thread; k++) {
|
||||||
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
|
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
|
||||||
w_local >>= bits;
|
w_local >>= bits;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,7 +109,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
|
|
||||||
// Store the result
|
// Store the result
|
||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
y[out_row] = result;
|
y[out_row] = static_cast<T>(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,15 +138,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
constexpr int colgroup = BN * el_per_int;
|
constexpr int colgroup = BN * el_per_int;
|
||||||
constexpr int groups_per_block = colgroup / group_size;
|
constexpr int groups_per_block = colgroup / group_size;
|
||||||
|
|
||||||
threadgroup T scales_block[BM * groups_per_block];
|
typedef typename AccT<T>::acc_t U;
|
||||||
threadgroup T biases_block[BM * groups_per_block];
|
threadgroup U scales_block[BM * groups_per_block];
|
||||||
threadgroup T x_block[BM];
|
threadgroup U biases_block[BM * groups_per_block];
|
||||||
|
threadgroup U x_block[BM];
|
||||||
|
|
||||||
thread uint32_t w_local;
|
thread uint32_t w_local;
|
||||||
thread T result[el_per_int] = {0};
|
thread U result[el_per_int] = {0};
|
||||||
thread T scale = 1;
|
thread U scale = 1;
|
||||||
thread T bias = 0;
|
thread U bias = 0;
|
||||||
thread T x_local = 0;
|
thread U x_local = 0;
|
||||||
|
|
||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||||
@ -186,7 +196,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
// Do all the work.
|
// Do all the work.
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int k=0; k<el_per_int; k++) {
|
for (int k=0; k<el_per_int; k++) {
|
||||||
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
|
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||||
w_local >>= bits;
|
w_local >>= bits;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -201,7 +211,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int k=0; k<el_per_int; k++) {
|
for (int k=0; k<el_per_int; k++) {
|
||||||
y[k] = result[k];
|
y[k] = static_cast<T>(result[k]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user