Quantized matmul fix (#677)

* Fix qmv for small or unaligned matrices

* Fix qmm
Angelos Katharopoulos authored 2024-02-12 18:54:21 -08:00, committed by GitHub
parent 4cc70290f7
commit 40c108766b
3 changed files with 81 additions and 9 deletions


@@ -39,11 +39,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / group_size;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
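
For orientation, these constants work out as follows under one illustrative configuration (4-bit weights, group_size = 64, SIMD_SIZE = BN = 32; assumed values for the example, not taken from the diff):

#include <cstdio>
// Minimal sketch of the qmv block constants (illustrative values only).
int main() {
  constexpr int bits = 4, group_size = 64, BN = 32;        // assumed: 4-bit weights, groups of 64, 32-wide simdgroups
  constexpr int bitmask = (1 << bits) - 1;                 // 0xF, selects one packed element
  constexpr int el_per_thread = 32 / bits;                 // 8 elements per 32-bit word, and per lane in the new load
  constexpr int colgroup = BN * el_per_thread;             // 256 columns of the input vector per outer iteration
  constexpr int groups_per_block = colgroup / group_size;  // 4 quantization groups per block
  std::printf("%d %d %d %d\n", bitmask, el_per_thread, colgroup, groups_per_block);
}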
@@ -66,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
+ if (out_row >= out_vec_size) {
+   return;
+ }
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
- if (simd_gid < simdgroups_fetching_vec) {
-   x_block[lid] = x[lid + i];
+ if (simd_gid == 0) {
+   #pragma clang loop unroll(full)
+   for (int j=0; j<el_per_thread; j++) {
+     x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
+   }
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
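
The two qmv changes above are an early return so threadgroups mapped past out_vec_size do no work (which the rounded-up grid in the eval_gpu hunk below now relies on), and a simpler vector fetch in which simdgroup 0 alone stages the colgroup-wide slice of x, each lane copying el_per_thread consecutive values. A plain C++ sketch of the new addressing, with the buffer sizes and bits = 4 assumed for illustration:

#include <cstdio>
#include <vector>
// Sketch of the new x staging: lane `simd_lid` of simdgroup 0 copies el_per_thread
// consecutive values starting at x[i + simd_lid * el_per_thread] into shared memory.
int main() {
  constexpr int bits = 4, SIMD_SIZE = 32;                      // assumed quantization and simd width
  constexpr int el_per_thread = 32 / bits;                     // 8
  constexpr int colgroup = SIMD_SIZE * el_per_thread;          // 256
  std::vector<float> x(1024);
  for (size_t n = 0; n < x.size(); n++) x[n] = float(n);
  std::vector<float> x_block(colgroup);
  int i = 256;                                                 // current offset of the outer loop over in_vec
  for (int simd_lid = 0; simd_lid < SIMD_SIZE; simd_lid++) {   // the 32 lanes run in lockstep on the GPU
    for (int j = 0; j < el_per_thread; j++) {
      x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
    }
  }
  std::printf("x_block spans x[%d..%d]\n", i, i + colgroup - 1);   // x[256..511]
}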
@@ -250,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
@@ -313,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
- if (y_col + offset_col < N) {
+ if (y_row + offset_row < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
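
For context, each scale/bias pair covers one quantization group, and offset_col / (group_size / el_per_int) picks the group that the packed word wi falls in. A small sketch of unpacking one 32-bit word, assuming the affine dequantization (value = q * scale + bias) implied by the scales/biases buffers and 4-bit packing; the concrete numbers are made up for illustration:

#include <cstdint>
#include <cstdio>
// Sketch: unpack one packed word into el_per_int weights (illustrative values only).
int main() {
  constexpr int bits = 4;
  constexpr int el_per_int = 32 / bits;             // 8 weights per uint32
  constexpr uint32_t bitmask = (1u << bits) - 1;    // 0xF
  uint32_t wi = 0x76543210u;                        // hypothetical packed weights 0..7
  float scale = 0.5f, bias = -1.0f;                 // the pair looked up from scales_block / biases_block
  for (int j = 0; j < el_per_int; j++) {
    uint32_t q = (wi >> (bits * j)) & bitmask;
    std::printf("%g ", q * scale + bias);           // roughly what the kernel stores into Ws_local[j]
  }
  std::printf("\n");
}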
@@ -428,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
- if (num_els < BM) {
-   loader_x.load_safe(short2(BK, num_els));
+ short num_k = min(BK, K - k);
+ if (num_els < BM || num_k < BK) {
+   loader_x.load_safe(short2(num_k, num_els));
} else {
loader_x.load_unsafe();
}
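
Here num_k = min(BK, K - k) clamps the tile width for the final block when K is not a multiple of BK, and load_safe copies only the valid num_k x num_els region (for example K = 100, BK = 32, k = 96 gives num_k = 4). The following is not the mlx::steel loader, just a sketch of that bounds handling under assumed shapes:

#include <algorithm>
#include <cstdio>
#include <vector>
// Sketch of a "safe" tile load: copy the valid num_els x num_k region of a BM x BK
// tile and zero-pad the rest so the block matmul never reads past the end of x.
constexpr int BM = 32, BK = 32;
void load_tile_safe(const std::vector<float>& x, int row0, int k, int K, int num_els,
                    float tile[BM][BK]) {
  int num_k = std::min(BK, K - k);                  // partial width on the last K block
  for (int r = 0; r < BM; r++)
    for (int c = 0; c < BK; c++)
      tile[r][c] = (r < num_els && c < num_k) ? x[(row0 + r) * K + (k + c)] : 0.0f;
}
int main() {
  int M = 5, K = 100;                               // small, unaligned shapes (illustrative)
  std::vector<float> x(M * K, 1.0f);
  static float tile[BM][BK];
  load_tile_safe(x, 0, 96, K, std::min(BM, M), tile);
  std::printf("%g %g\n", tile[0][3], tile[0][4]);   // 1 inside the valid region, 0 in the padding
}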
@@ -457,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
- if (k + BK >= K) {
+ if (num_k < BK) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
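
The old guard k + BK >= K sent even a fully aligned final block (K an exact multiple of BK) through the guarded path; the new check reuses the num_k computed when loading the x tile and takes the guarded path only when the block is genuinely partial, e.g. num_k = 32 for K = 64 but num_k = 6 for K = 70 with BK = 32.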


@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
- MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+ MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
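
This is the host-side half of the qmv fix: the grid's second dimension now rounds up, so output sizes that are not multiples of bo still get a threadgroup, and the new out_row >= out_vec_size return in the kernel makes the overhanging rows exit immediately. A sketch of the dispatch arithmetic for an assumed small output size:

#include <algorithm>
#include <cstdio>
// Sketch of the grid sizing before and after the fix, for an illustrative O.
int main() {
  int O = 50;                            // out_vec_size, deliberately not a multiple of 32
  int B = 1;                             // batch size
  int bo = std::min(32, O);              // 32
  int old_rows = O / bo;                 // 1: rows 32..49 of the output were never computed
  int new_rows = (O + bo - 1) / bo;      // 2: covers every row; the extra ones return early
  std::printf("grid y: old=%d new=%d (batch=%d)\n", old_rows, new_rows, B);
}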