Add group size 32 int qmms

Jagrit Digani
2025-11-19 12:00:59 -08:00
parent 18807aae0b
commit 166dfac5cf
3 changed files with 146 additions and 5 deletions

@@ -692,6 +692,146 @@ struct QuantizedBlockLoader {
  }
};

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short bits>
struct QuantizedBlockLoader<
    T,
    BROWS,
    BCOLS,
    dst_ld,
    reduction_dim,
    tgp_size,
    32,
    bits> {
  MLX_MTL_CONST short group_size = 32;

  static_assert(
      BCOLS % group_size == 0,
      "BCOLS should be divisible by the group size");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short n_groups = BCOLS / group_size;
  static_assert(
      (BCOLS_PACKED / n_reads) == n_groups,
      "Other configurations are not yet supported");
  const int src_ld;
  const int tile_stride;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;
  const short group_id;

  threadgroup T* dst;
  const device uint8_t* src;
  const device T* scales;
  const device T* biases;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device T* scales_,
      const device T* biases_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        group_id((bj * pack_factor) / group_size),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size + group_id),
        biases(biases_ + bi * src_ld / group_size + group_id) {}

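  // Load and dequantize one tile with no bounds checks on the source: each
  // thread reads its packed bytes and expands them into threadgroup memory
  // using the scale and bias of the single group it covers.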
  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i * bytes_per_pack),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }

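  // Advance the loader by one tile along the reduction dimension, stepping the
  // scale and bias pointers to the groups of the next tile.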
  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      // if (group_steps > 1) {
      //   group_step_cnt++;
      //   if (group_step_cnt == group_steps) {
      //     group_step_cnt = 0;
      //     scales++;
      //     biases++;
      //   }
      // } else {
      scales += n_groups;
      biases += n_groups;
      // }
    } else {
      scales += n_groups * group_stride;
      biases += n_groups * group_stride;
    }
  }
};

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
@@ -843,7 +983,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
  biases += y_col * K_g;
  y += y_row * static_cast<int64_t>(N) + y_col;

-  // Make the x loader and mma operation
+  // Make the weight loader
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);

  constexpr short UM = 16;
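
For reference, here is a minimal host-side sketch of how the new loader's compile-time constants work out for 4-bit weights with group size 32. The tile shape and threadgroup size below are illustrative assumptions, not values taken from this commit, and the pack_factor line mirrors the 4-bit case of get_pack_factor<bits, 8>():

constexpr int bits = 4;
constexpr int group_size = 32;
constexpr int BROWS = 64, BCOLS = 64; // hypothetical tile shape
constexpr int tgp_size = 128;         // hypothetical threadgroup size

constexpr int pack_factor = 8 / bits;                      // 2 values per byte
constexpr int BCOLS_PACKED = BCOLS / pack_factor;          // 32 packed bytes per row
constexpr int n_reads = (BCOLS_PACKED * BROWS) / tgp_size; // 16 bytes per thread
constexpr int n_groups = BCOLS / group_size;               // 2 groups per row

// Each thread expands n_reads * pack_factor = 32 values, i.e. exactly one
// quantization group, which is the configuration this specialization supports.
static_assert(n_reads * pack_factor == group_size, "one group per thread");
static_assert((BCOLS_PACKED / n_reads) == n_groups, "loader constraint holds");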

@@ -92,7 +92,8 @@
#define instantiate_quantized_groups(bits) \
  instantiate_quantized_types(128, bits) \
-  instantiate_quantized_types(64, bits)
+  instantiate_quantized_types(64, bits) \
+  instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
  instantiate_quantized_groups(2) \

@@ -675,7 +675,7 @@ void qmm(
  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
    if (metal::is_nax_available() && transpose &&
        (x.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
-        (group_size >= 64) && (K % 64 == 0)) {
+        (K % 64 == 0)) {
      return qmm_nax(
          /* const array& x = */ x,
          /* const array& w = */ w,
@@ -777,8 +777,8 @@ void gather_qmm(
  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
    if (metal::is_nax_available() && transpose &&
-        (x.dtype() != float32 || env::enable_tf32()) && transpose &&
-        mode == "affine" && (group_size >= 64) && (K % 64 == 0)) {
+        (x.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
+        (K % 64 == 0)) {
      return gather_qmm_nax(
          /* const array& x = */ x,
          /* const array& w = */ w,