Simplifying and improving qmm (#1030)

Angelos Katharopoulos 2024-04-24 13:07:45 -07:00 committed by GitHub
parent ec8578d41a
commit 20a01bbd9f
2 changed files with 262 additions and 202 deletions


@@ -201,6 +201,150 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result
   }
 }
 
+template <typename U, int N, int bits>
+inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  if (bits == 2) {
+    U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
+      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
+      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
+      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
+    }
+  } else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
+      w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
+      w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
+      w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      w_local[i] = scale * w[i] + bias;
+    }
+  }
+}
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct QuantizedBlockLoader {
+  static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
+  static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  threadgroup T* dst;
+  const device uint32_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld / pack_factor + bj),
+        scales(scales_ + bi * src_ld / group_size),
+        biases(biases_ + bi * src_ld / group_size) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales++;
+          biases++;
+        }
+      } else {
+        scales++;
+        biases++;
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
 template <typename T, int group_size, int bits, int packs_per_thread>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
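Note: the 2- and 4-bit paths in dequantize above fold the bit shift into the scale — multiplying the still-shifted masked value by scale / 2^(bits*i) gives the same result as shifting first and then scaling, saving an op per element. A minimal host-side C++ check of that identity (plain C++ stand-in, not kernel code; the power-of-two scale keeps the float comparison exact):

#include <cassert>
#include <cstdint>

int main() {
  float scale = 0.25f, bias = -1.0f;
  uint16_t w = 0xA74C; // four packed 4-bit values: 0xC, 0x4, 0x7, 0xA
  float s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
  uint16_t masks[4] = {0x000f, 0x00f0, 0x0f00, 0xf000};
  for (int i = 0; i < 4; i++) {
    float fast = s[i] * (w & masks[i]) + bias;           // premultiplied scale, no shift
    float ref = scale * ((w >> (4 * i)) & 0xf) + bias;   // shift-then-scale reference
    assert(fast == ref);
  }
  return 0;
}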
@@ -495,28 +639,23 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
-  constexpr int groups_per_simd = BN / (WM * WN);
-  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
-  threadgroup T scales_block[BN * groups_per_block];
-  threadgroup T biases_block[BN * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BN * BK];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
   // Set the block
-  const int K_w = K / el_per_int;
+  const int K_w = K / pack_factor;
   const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
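Note: BK_padded extends each threadgroup tile row by 16 bytes worth of elements, so consecutive rows start at staggered offsets in threadgroup memory. A sketch of the arithmetic (hypothetical padded_ld helper; uint16_t stands in for half on the host):

template <typename T>
constexpr int padded_ld(int bk) {
  // 16 extra bytes per row, expressed in elements of T
  return bk + 16 / sizeof(T);
}

static_assert(padded_ld<unsigned short>(32) == 40, "half: 32 + 16/2");
static_assert(padded_ld<float>(32) == 36, "float: 32 + 16/4");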
@@ -531,98 +670,53 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size
   const short num_els = min(BM, M - y_row);
   const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
+  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
-  for (int k=0; k<K; k += BK) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the x tile
-    if (num_els < BM) {
-      loader_x.load_safe(short2(BK, num_els));
-    } else {
-      loader_x.load_unsafe();
-    }
-    // Load the scale and bias
-    if (simd_lid == 0) {
-      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
-      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
-      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
-      #pragma clang loop unroll(full)
-      for (int gs=0; gs<groups_per_simd; gs++) {
-        #pragma clang loop unroll(full)
-        for (int gc=0; gc<groups_per_block; gc++) {
-          scales_block_local[gc] = scales_local[gc];
-          biases_block_local[gc] = biases_local[gc];
-        }
-        scales_block_local += groups_per_block;
-        scales_local += K_g;
-        biases_block_local += groups_per_block;
-        biases_local += K_g;
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the w tile
-    {
-      if (!aligned_N && num_outs < BN) {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BK / el_per_int);
-          int offset_col = offset % (BK / el_per_int);
-          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-          // y_col corresponds to the row of the weight matrix and added to
-          // offset_row it should be less than the total number of rows
-          // otherwise skip.
-          if (y_col + offset_row < N) {
-            uint32_t wi = *w_local;
-            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-              wi >>= bits;
-            }
-          } else {
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = 0;
-            }
-          }
-        }
-      } else {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BK / el_per_int);
-          int offset_col = offset % (BK / el_per_int);
-          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-          uint32_t wi = *w_local;
-          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-    // Prepare for next iteration
-    loader_x.next();
-    w += ints_per_block;
-    // scales and biases cannot be advanced because they would have to be
-    // advanced every other iteration or sth.
-  }
+  if (num_els < BM) {
+    if (!aligned_N && num_outs < BN) {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  } else {
+    if (!aligned_N && num_outs < BN) {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  }
 
   // Store results to device memory
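Note: the rewrite hoists the two bounds checks (partial M tile, unaligned N) out of the hot K loop, so each threadgroup runs one of four branch-free inner loops. A shape-only sketch of that pattern (hypothetical names, not the kernel itself):

// Pick safe/unsafe load variants once, outside the K loop, instead of
// re-checking bounds on every iteration.
template <typename LoadX, typename LoadW, typename Step>
void k_loop(int K, int BK, LoadX load_x, LoadW load_w, Step step) {
  for (int k = 0; k < K; k += BK) {
    load_x();
    load_w();
    step(); // barrier + mma + advance in the real kernel
  }
}

// Usage: four specializations, selected once per threadgroup:
//   partial M, partial N  -> k_loop(.., safe_x,   safe_w,   ..)
//   partial M, full N     -> k_loop(.., safe_x,   unsafe_w, ..)
//   full M,    partial N  -> k_loop(.., unsafe_x, safe_w,   ..)
//   full M,    full N     -> k_loop(.., unsafe_x, unsafe_w, ..)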
@@ -653,32 +747,27 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
-  constexpr int groups_per_simd = BK / (WM * WN);
-  constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
-  threadgroup T scales_block[BK * groups_per_block];
-  threadgroup T biases_block[BK * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BK * BN];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
   // Set the block
-  const int N_w = N / el_per_int;
-  const int N_g = N / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
-  w += y_col / el_per_int;
+  w += y_col / pack_factor;
   scales += y_col / group_size;
   biases += y_col / group_size;
   y += y_row * N + y_col;
@@ -686,96 +775,67 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
+  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
-  for (int k=0; k<K; k += BK) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the x tile
-    short num_k = min(BK, K - k);
-    if (num_els < BM || num_k < BK) {
-      loader_x.load_safe(short2(num_k, num_els));
-    } else {
-      loader_x.load_unsafe();
-    }
-    // Load the scale and bias
-    if (simd_lid == 0) {
-      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
-      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * N_g;
-      const device T *biases_local = biases + lidy * groups_per_simd * N_g;
-      #pragma clang loop unroll(full)
-      for (int gs=0; gs<groups_per_simd; gs++) {
-        #pragma clang loop unroll(full)
-        for (int gc=0; gc<groups_per_block; gc++) {
-          scales_block_local[gc] = scales_local[gc];
-          biases_block_local[gc] = biases_local[gc];
-        }
-        scales_block_local += groups_per_block;
-        scales_local += N_g;
-        biases_block_local += groups_per_block;
-        biases_local += N_g;
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the w tile
-    {
-      if (num_k < BK) {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BN / el_per_int);
-          int offset_col = offset % (BN / el_per_int);
-          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-          if (k + offset_row < K) {
-            uint32_t wi = *w_local;
-            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-              wi >>= bits;
-            }
-          } else {
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = 0;
-            }
-          }
-        }
-      } else {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BN / el_per_int);
-          int offset_col = offset % (BN / el_per_int);
-          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-          uint32_t wi = *w_local;
-          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-    // Prepare for next iteration
-    loader_x.next();
-    w += BK * N_w;
-    scales += BK * N_g;
-    biases += BK * N_g;
-  }
+  if (num_els < BM) {
+    if ((K % BK) != 0) {
+      const int k_blocks = K / BK;
+      for (int k=0; k<k_blocks; k++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+      const short num_k = K - k_blocks * BK;
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_x.load_safe(short2(num_k, num_els));
+      loader_w.load_safe(short2(BN, num_k));
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      mma_op.mma(Xs, Ws);
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  } else {
+    if ((K % BK) != 0) {
+      const int k_blocks = K / BK;
+      for (int k=0; k<k_blocks; k++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+      const short num_k = K - k_blocks * BK;
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_x.load_safe(short2(num_k, BM));
+      loader_w.load_safe(short2(BN, num_k));
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      mma_op.mma(Xs, Ws);
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  }
 
   // Store results to device memory
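Note: when K is not a multiple of BK, the new qmm_n peels the ragged tail out of the loop instead of testing num_k on every iteration: full blocks use load_unsafe, and a single final bounds-checked block handles the remainder. A scalar sketch with worked numbers (hypothetical helper, not the kernel):

#include <cassert>

// Loop peeling as used above: e.g. K = 100, BK = 32 gives 3 full blocks
// plus a 4-wide tail that takes the guarded load.
void peeled_k_loop(int K, int BK) {
  const int k_blocks = K / BK;          // 100 / 32 = 3 full blocks
  for (int k = 0; k < k_blocks; k++) {
    // load_unsafe(); mma; advance;     // fast path, no bounds checks
  }
  const int num_k = K - k_blocks * BK;  // 100 - 96 = 4 remaining columns
  assert(num_k < BK);
  // load_safe(num_k); mma;             // one guarded tail block
}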
@@ -877,7 +937,7 @@ instantiate_qvm_types( 32, 8)
 #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
+  [[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
@@ -911,7 +971,7 @@ instantiate_qmm_t_types( 32, 8)
 #define instantiate_qmm_n(name, itype, group_size, bits) \
   template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
+  [[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
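Note: the host_name attribute stringizes the macro arguments into the kernel's lookup name, so the host dispatch must assemble the identical string. An illustrative reconstruction (hypothetical helper; the "true"/"false" spelling is an assumption based on the #aligned_N stringization):

#include <string>

std::string qmm_t_name(const std::string& type_string, int group_size, int bits, bool aligned_n) {
  return "qmm_t_" + type_string + "_gs_" + std::to_string(group_size) +
         "_b_" + std::to_string(bits) + "_alN_" + (aligned_n ? "true" : "false");
}
// qmm_t_name("float16", 64, 4, true) -> "qmm_t_float16_gs_64_b_4_alN_true"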


@@ -110,7 +110,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int wm = 2;
     int bm = 32;
     int bn = 32;
-    int bk = 64;
+    int bk = 32;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
@@ -167,7 +167,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int wn = 2;
     int wm = 2;
     int bm = 32;
-    int bn = 64;
+    int bn = 32;
     int bk = 32;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
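Note: the qmm_t launch ceil-divides both output dimensions, so partial tiles are allowed on either edge, while the qmm_n launch uses O / bn exactly, which assumes the output columns are a multiple of bn. A sketch of the grid arithmetic (hypothetical helper, plain C++):

#include <utility>

// Grid over the B x O output: one threadgroup per (bm x bn) tile.
std::pair<int, int> qmm_grid(int B, int O, int bm = 32, int bn = 32) {
  int grid_x = (O + bn - 1) / bn; // ceil(O / bn), e.g. O = 100 -> 4 column blocks
  int grid_y = (B + bm - 1) / bm; // partial row tiles handled by load_safe
  return {grid_x, grid_y};
}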