Simplifying and improving qmm (#1030)
commit 20a01bbd9f (parent ec8578d41a), from the mirror of https://github.com/ml-explore/mlx.git
@@ -201,6 +201,150 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result
   }
 }
 
+template <typename U, int N, int bits>
+inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  if (bits == 2) {
+    U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4*i] = s[0] * (w[i] & 0x03) + bias;
+      w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias;
+      w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias;
+      w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias;
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias;
+      w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias;
+      w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias;
+      w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias;
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      w_local[i] = scale * w[i] + bias;
+    }
+  }
+}
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct QuantizedBlockLoader {
+  static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
+  static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  threadgroup T* dst;
+  const device uint32_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld / pack_factor + bj),
+        scales(scales_ + bi * src_ld / group_size),
+        biases(biases_ + bi * src_ld / group_size) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales++;
+          biases++;
+        }
+      } else {
+        scales++;
+        biases++;
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
 template <typename T, int group_size, int bits, int packs_per_thread>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
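Worth noting about the dequantize routine above: rather than shifting each packed field down before scaling, it masks the field in place and folds the shift into a pre-divided scale (s[k] = scale / 4^k in the 2-bit case), which trades the per-element shifts for a few divisions done once. A minimal host-side C++ sketch (hypothetical standalone code, not part of the diff) checking that equivalence for 2-bit fields:

#include <cassert>
#include <cstdint>

// Naive formulation: shift the field down, then apply scale and bias.
float dequant_shift(uint8_t w, int k, float scale, float bias) {
  return scale * static_cast<float>((w >> (2 * k)) & 0x3) + bias;
}

// Folded formulation used by the kernel: mask in place, pre-divide the scale.
float dequant_folded(uint8_t w, int k, float scale, float bias) {
  const float s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
  const uint8_t m[4] = {0x03, 0x0c, 0x30, 0xc0};
  return s[k] * static_cast<float>(w & m[k]) + bias;
}

int main() {
  // Both forms agree for every packed byte and every field position.
  for (int w = 0; w < 256; w++) {
    for (int k = 0; k < 4; k++) {
      assert(dequant_shift(uint8_t(w), k, 0.5f, -1.0f) ==
             dequant_folded(uint8_t(w), k, 0.5f, -1.0f));
    }
  }
}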
@@ -495,28 +639,23 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
-  constexpr int groups_per_simd = BN / (WM * WN);
-  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
 
-  threadgroup T scales_block[BN * groups_per_block];
-  threadgroup T biases_block[BN * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BN * BK];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
 
   // Set the block
-  const int K_w = K / el_per_int;
+  const int K_w = K / pack_factor;
   const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
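The new BK_padded constant pads each threadgroup tile row by 16 bytes' worth of elements. The diff does not state the motivation; a plausible reading is that padding the leading dimension staggers consecutive rows across threadgroup-memory banks. An illustrative C++ sketch of the padded layout arithmetic, under that assumption:

#include <cstdio>

// Illustrative: a BM x BK tile stored with a padded leading dimension.
// Row r, column c lives at Xs[r * BK_padded + c], so consecutive rows
// start BK + 16/sizeof(T) elements apart rather than exactly BK.
template <typename T, int BM, int BK>
void show_padded_layout() {
  constexpr int BK_padded = BK + int(16 / sizeof(T));
  std::printf("tile %dx%d, leading dim %d, storage %d elements\n",
              BM, BK, BK_padded, BM * BK_padded);
}

int main() {
  show_padded_layout<short, 32, 32>();  // stand-in for a 2-byte T: pad by 8
  show_padded_layout<float, 32, 32>();  // 4-byte T: pad by 4
}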
@@ -531,98 +670,53 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   const short num_els = min(BM, M - y_row);
   const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
+  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
-  for (int k=0; k<K; k += BK) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the x tile
-    if (num_els < BM) {
-      loader_x.load_safe(short2(BK, num_els));
-    } else {
-      loader_x.load_unsafe();
-    }
-
-    // Load the scale and bias
-    if (simd_lid == 0) {
-      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
-      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
-      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
-      #pragma clang loop unroll(full)
-      for (int gs=0; gs<groups_per_simd; gs++) {
-        #pragma clang loop unroll(full)
-        for (int gc=0; gc<groups_per_block; gc++) {
-          scales_block_local[gc] = scales_local[gc];
-          biases_block_local[gc] = biases_local[gc];
-        }
-        scales_block_local += groups_per_block;
-        scales_local += K_g;
-        biases_block_local += groups_per_block;
-        biases_local += K_g;
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load the w tile
-    {
-      if (!aligned_N && num_outs < BN) {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BK / el_per_int);
-          int offset_col = offset % (BK / el_per_int);
-          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-
-          // y_col corresponds to the row of the weight matrix and added to
-          // offset_row it should be less than the total number of rows
-          // otherwise skip.
-          if (y_col + offset_row < N) {
-            uint32_t wi = *w_local;
-            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-              wi >>= bits;
-            }
-          } else {
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = 0;
-            }
-          }
-        }
-      } else {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BK / el_per_int);
-          int offset_col = offset % (BK / el_per_int);
-          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-
-          uint32_t wi = *w_local;
-          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-
-    // Prepare for next iteration
-    loader_x.next();
-    w += ints_per_block;
-    // scales and biases cannot be advanced because they would have to be
-    // advanced every other iteration or sth.
-  }
+  if (num_els < BM) {
+    if (!aligned_N && num_outs < BN) {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  } else {
+    if (!aligned_N && num_outs < BN) {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  }
 
   // Store results to device memory
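The rewrite above hoists the tile-boundary checks (num_els < BM, and !aligned_N && num_outs < BN) out of the K loop, so each of the four combinations runs an inner loop with a fixed load path. A generic C++ sketch of that hoisting pattern, with hypothetical helper names chosen for illustration:

#include <cstdio>

// Illustrative pattern: choose one of four specialized inner loops once,
// rather than re-testing tile bounds on every step along K.
template <bool SafeX, bool SafeW, typename Step>
void run_k_loop(int K, int BK, Step step) {
  for (int k = 0; k < K; k += BK) {
    step(SafeX, SafeW);  // a real kernel calls load_safe/load_unsafe per flag
  }
}

template <typename Step>
void qmm_main_loop(int K, int BK, bool partial_rows, bool partial_cols, Step step) {
  if (partial_rows) {
    if (partial_cols) run_k_loop<true, true>(K, BK, step);
    else              run_k_loop<true, false>(K, BK, step);
  } else {
    if (partial_cols) run_k_loop<false, true>(K, BK, step);
    else              run_k_loop<false, false>(K, BK, step);
  }
}

int main() {
  qmm_main_loop(128, 32, true, false, [](bool sx, bool sw) {
    std::printf("step: safe_x=%d safe_w=%d\n", sx, sw);
  });
}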
@@ -653,32 +747,27 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
-  const uint lidy = lid / SIMD_SIZE;
+  (void)lid;
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
-  constexpr int groups_per_simd = BK / (WM * WN);
-  constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
+  constexpr int pack_factor = 32 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
 
-  threadgroup T scales_block[BK * groups_per_block];
-  threadgroup T biases_block[BK * groups_per_block];
-  threadgroup T Xs[BM * BK];
-  threadgroup T Ws[BK * BN];
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
 
   // Set the block
-  const int N_w = N / el_per_int;
-  const int N_g = N / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
-  w += y_col / el_per_int;
+  w += y_col / pack_factor;
   scales += y_col / group_size;
   biases += y_col / group_size;
   y += y_row * N + y_col;
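Note that qmm_n instantiates QuantizedBlockLoader with reduction_dim = 0 (weight tiles advance by whole rows, with N as the leading dimension), whereas qmm_t above used reduction_dim = 1. Per the constructor in the first hunk, that flag selects how far src advances per tile. A small C++ sketch mirroring the loader's tile_stride expression, with illustrative values:

#include <cstdio>

// Mirror of the loader's tile_stride choice: along the reduction dim the
// loader steps BCOLS_PACKED packed words; otherwise it jumps BROWS rows.
constexpr int tile_stride(bool reduction_dim, int BROWS, int BCOLS,
                          int src_ld, int bits) {
  const int pack_factor = 32 / bits;           // elements per uint32
  const int bcols_packed = BCOLS / pack_factor;
  return reduction_dim ? bcols_packed : BROWS * src_ld / pack_factor;
}

int main() {
  // qmm_t-style loader: BN x BK tile, K = 512 as leading dim, 4-bit weights.
  std::printf("qmm_t stride: %d words\n", tile_stride(true, 32, 32, 512, 4));
  // qmm_n-style loader: BK x BN tile, N = 512 as leading dim.
  std::printf("qmm_n stride: %d words\n", tile_stride(false, 32, 32, 512, 4));
}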
@@ -686,96 +775,67 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
+  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
-  for (int k=0; k<K; k += BK) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load the x tile
-    short num_k = min(BK, K - k);
-    if (num_els < BM || num_k < BK) {
-      loader_x.load_safe(short2(num_k, num_els));
-    } else {
-      loader_x.load_unsafe();
-    }
-
-    // Load the scale and bias
-    if (simd_lid == 0) {
-      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
-      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * N_g;
-      const device T *biases_local = biases + lidy * groups_per_simd * N_g;
-      #pragma clang loop unroll(full)
-      for (int gs=0; gs<groups_per_simd; gs++) {
-        #pragma clang loop unroll(full)
-        for (int gc=0; gc<groups_per_block; gc++) {
-          scales_block_local[gc] = scales_local[gc];
-          biases_block_local[gc] = biases_local[gc];
-        }
-        scales_block_local += groups_per_block;
-        scales_local += N_g;
-        biases_block_local += groups_per_block;
-        biases_local += N_g;
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load the w tile
-    {
-      if (num_k < BK) {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BN / el_per_int);
-          int offset_col = offset % (BN / el_per_int);
-          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-
-          if (k + offset_row < K) {
-            uint32_t wi = *w_local;
-            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-              wi >>= bits;
-            }
-          } else {
-            #pragma clang loop unroll(full)
-            for (int t=0; t<el_per_int; t++) {
-              Ws_local[t] = 0;
-            }
-          }
-        }
-      } else {
-        for (int wo=0; wo<w_els_per_thread; wo++) {
-          int offset = lid * w_els_per_thread + wo;
-          int offset_row = offset / (BN / el_per_int);
-          int offset_col = offset % (BN / el_per_int);
-          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-
-          uint32_t wi = *w_local;
-          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-          #pragma clang loop unroll(full)
-          for (int t=0; t<el_per_int; t++) {
-            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-            wi >>= bits;
-          }
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(Xs, Ws);
-
-    // Prepare for next iteration
-    loader_x.next();
-    w += BK * N_w;
-    scales += BK * N_g;
-    biases += BK * N_g;
-  }
+  if (num_els < BM) {
+    if ((K % BK) != 0) {
+      const int k_blocks = K/BK;
+      for (int k=0; k<k_blocks; k++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+      const short num_k = K - k_blocks * BK;
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_x.load_safe(short2(num_k, num_els));
+      loader_w.load_safe(short2(BN, num_k));
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      mma_op.mma(Xs, Ws);
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  } else {
+    if ((K % BK) != 0) {
+      const int k_blocks = K/BK;
+      for (int k=0; k<k_blocks; k++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+      const short num_k = K - k_blocks * BK;
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_x.load_safe(short2(num_k, BM));
+      loader_w.load_safe(short2(BN, num_k));
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      mma_op.mma(Xs, Ws);
+    } else {
+      for (int k=0; k<K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  }
 
   // Store results to device memory
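When K is not a multiple of BK, the new qmm_n loop peels the tail instead of checking bounds on every iteration: k_blocks full iterations on the fast load_unsafe path, then a single bounds-checked epilogue of num_k = K - k_blocks * BK columns. A minimal C++ sketch of that peeling, illustrative only:

#include <cstdio>

// Illustrative loop peeling: full blocks take the fast unchecked path,
// the remainder (if any) takes one bounds-checked epilogue step.
void blocked_walk(int K, int BK) {
  const int k_blocks = K / BK;
  for (int k = 0; k < k_blocks; k++) {
    std::printf("full block of %d\n", BK);     // load_unsafe in the kernel
  }
  const int num_k = K - k_blocks * BK;
  if (num_k > 0) {
    std::printf("tail block of %d\n", num_k);  // load_safe in the kernel
  }
}

int main() {
  blocked_walk(100, 32);  // 3 full blocks + a tail of 4
}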
@@ -877,7 +937,7 @@ instantiate_qvm_types( 32, 8)
 
 #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
+  [[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
@@ -911,7 +971,7 @@ instantiate_qmm_t_types( 32, 8)
 
 #define instantiate_qmm_n(name, itype, group_size, bits) \
   template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
+  [[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
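For reference, the host_name attribute in these macros stringizes the macro arguments into the name the kernel is registered under. A hypothetical C++ expansion (argument values chosen for illustration, not taken from the diff):

#include <cstdio>

// Mirrors the stringization in the macro:
// "qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N
#define QMM_T_HOST_NAME(name, group_size, bits, aligned_N) \
  "qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N

int main() {
  // e.g. a float16 kernel, 64-wide groups, 4-bit weights, aligned N:
  std::puts(QMM_T_HOST_NAME(float16, 64, 4, true));
  // prints: qmm_t_float16_gs_64_b_4_alN_true
}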
@@ -110,7 +110,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   int wm = 2;
   int bm = 32;
   int bn = 32;
-  int bk = 64;
+  int bk = 32;
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
 
@@ -167,7 +167,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   int wn = 2;
   int wm = 2;
   int bm = 32;
-  int bn = 64;
+  int bn = 32;
   int bk = 32;
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
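On the host side, both paths now dispatch 32x32x32 tiles. Note the qmm_t launch rounds O up with (O + bn - 1) / bn while the qmm_n launch uses O / bn, which suggests that path assumes the output columns are tile-aligned. A small C++ sketch of the ceil-division used for the ragged case:

#include <cstdio>

// Ceil-division: enough tiles to cover n even when n % d != 0.
constexpr int ceil_div(int n, int d) { return (n + d - 1) / d; }

int main() {
  const int O = 100, B = 7, bn = 32, bm = 32;
  // qmm_t-style grid: covers ragged edges in both output dimensions.
  std::printf("grid = %d x %d\n", ceil_div(O, bn), ceil_div(B, bm));  // 4 x 1
}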