Add the 3bit packed qmm_t

Angelos Katharopoulos 2024-12-17 22:16:30 -08:00
parent d75a509234
commit c02e14c264


@@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales += 8;
          biases += 8;
          scales += (2 * row_pack_factor);
          biases += (2 * row_pack_factor);
        }
      } else {
        scales += 8;
        biases += 8;
        scales += (2 * row_pack_factor);
        biases += (2 * row_pack_factor);
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};
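// Weight loader for the layout where only the scales and biases are packed
// in pairs of adjacent rows (row_pack_factor == 2) while the quantized
// weight words keep their regular row-major layout. Selected below via
// ConditionalType when bits is not a power of two, i.e. for the 3- and
// 6-bit formats that use 3-byte packs.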
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct AffineScalesPackedQuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");
  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  MLX_MTL_CONST short row_pack_factor = 2;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;
  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;
  const short thread_idx;
  const short bi;
  const short bj;
  const short bii;
  const device uint8_t* src;
  const device T* scales;
  const device T* biases;
  threadgroup T* dst;
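  // The scales pointer starts at the group covering this thread's row pair:
  // bii indexes the row pair, each group stores 2 * row_pack_factor values
  // (scales then biases), and bi % row_pack_factor selects the scale for the
  // row within the pair. The matching bias lives row_pack_factor entries
  // further on.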
  AffineScalesPackedQuantizedBlockLoader(
      const device uint32_t* src_,
      const device T* scales_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        bii(bi / row_pack_factor),
        src(((const device uint8_t*)src_) +
            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
        scales(
            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
            bi % row_pack_factor),
        biases(scales + row_pack_factor),
        dst(dst_ + bi * dst_ld + bj * pack_factor) {}
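  // Dequantize this thread's n_reads packs of pack_factor values into the
  // threadgroup tile, using the scale/bias of the current group.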
  void load_unsafe() const {
    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
      return;
    }
    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + bytes_per_pack * i),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }
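  // Bounds-checked variant of load_unsafe: row pairs that fall outside the
  // source tile are zero-filled instead of read.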
  void load_safe(short2 src_tile_dim) const {
    if (TOTAL_READS < tgp_size && bi >= BROWS) {
      return;
    }
    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }
    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }
    for (int i = 0; i < n_reads; i++) {
      T scale = scales[i];
      T bias = biases[i];
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + bytes_per_pack * i * src_ld),
          scale,
          bias,
          dst + i * dst_ld);
    }
  }
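  // Advance to the next tile. Along the reduction dimension the scales and
  // biases only move once a full quantization group (group_steps tiles) has
  // been consumed; otherwise they jump by group_stride to the next block of
  // rows.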
  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales += (2 * row_pack_factor);
          biases += (2 * row_pack_factor);
        }
      } else {
        scales += (2 * row_pack_factor);
        biases += (2 * row_pack_factor);
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};
@@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl(
  (void)lid;
  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int row_pack_factor = 4;
  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  // Instantiate the appropriate BlockMMA and Loader
@@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl(
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = AffinePackedQuantizedBlockLoader<
  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
      T,
      BN,
      BK,
@@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl(
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;
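  // 3- and 6-bit weights are not row-packed, so they use the loader that only
  // expects the scales/biases to be packed in pairs of rows.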
  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;
  using loader_w_t = typename ConditionalType<
      power_of_2_bits,
      loader_fully_packed_t,
      loader_scales_packed_t>::type;
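  // K_w is the weight row stride in 32-bit words: a packed row folds
  // row_pack_factor rows together for power-of-2 widths, while a 3- or 6-bit
  // row occupies K * bits / 32 words on its own. K_g is the scale/bias stride:
  // 2 * row_pack_factor values (scales then biases) per group along K.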
  // Set the block
  const int K_w = K * row_pack_factor / pack_factor;
  const int K_w =
      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
  const int K_g = K * 2 * row_pack_factor / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const int packed_y_col = tid.x * (BN / row_pack_factor);
  x += y_row * K;
  w += packed_y_col * K_w;
  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
  scales += packed_y_col * K_g;
  y += y_row * N + y_col;
@@ -2692,9 +2844,6 @@ template <
        s_strides,
        tid);
  }
  if (bits & (bits - 1)) {
  } else {
    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
  }
  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}