Quantize with groups of 32 (#511)

* Allow quantize with a group size of 32

* Add missing CPU dispatch

* Remove a leftover print from the losses test

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-01-21 06:19:05 -08:00 committed by GitHub
parent 92c22c1ea3
commit 7a34e46677
6 changed files with 66 additions and 27 deletions
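Before the per-file diffs, a minimal round-trip sketch of the newly supported group size, using the mx.quantize / mx.dequantize calls that the test changes at the bottom of this commit exercise:

import mlx.core as mx

# Quantize a weight matrix with the newly supported group size of 32.
w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, 32, 4)      # group_size=32, bits=4

# Round-trip back to floating point; each element should land within half a
# group scale of the original, which is what the updated test below asserts.
w_hat = mx.dequantize(w_q, scales, biases, 32, 4)
print((w - w_hat).abs().max())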

View File

@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -84,6 +87,15 @@ quant_matmul = {
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
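
The registry above names each configuration quant_matmul[_t]_<group_size>_<bits>, so the new 32-entries slot in alongside 64 and 128. A small sketch of how a benchmark might build a key and pick the matching partial; the x, w, s, b arguments and the surrounding driver code are not part of this diff, so the call at the end is illustrative only:

# Build a registry key in the same format as the entries added above.
group_size, bits, transpose = 32, 4, True
key = f"quant_matmul{'_t' if transpose else ''}_{group_size}_{bits}"
print(key)  # -> quant_matmul_t_32_4
# quant_matmul[key](x, w, s, b) would then run the selected configuration.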

View File

@@ -119,6 +119,12 @@ void _qmm_dispatch_typed(
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -135,6 +141,12 @@ void _qmm_dispatch_typed(
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -151,6 +163,12 @@ void _qmm_dispatch_typed(
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
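
The hunks above give the CPU backend the same group_size 32 cases that already existed for 64 and 128, for both the plain and transposed templates. A hedged Python sketch that walks the full (group_size, bits) grid the dispatcher now covers on the CPU stream; mx.quantized_matmul's keyword names and the stream argument are assumptions based on the existing API, not shown in this diff:

import mlx.core as mx
from itertools import product

x = mx.random.normal(shape=(8, 512))
w = mx.random.normal(shape=(512, 512))

# One call per (group_size, bits) case handled by _qmm_dispatch_typed.
for group_size, bits in product([32, 64, 128], [2, 4, 8]):
    w_q, scales, biases = mx.quantize(w, group_size, bits)
    # transpose=True exercises the _qmm_t branch; False would hit _qmm.
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True,
        group_size=group_size, bits=bits, stream=mx.cpu,
    )
    mx.eval(y)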

View File

@@ -142,10 +142,11 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = (tid.y * BN + simd_gid) * el_per_int;
int out_col_start = tid.y * (BN * el_per_int);
int out_col = out_col_start + simd_gid * el_per_int;
w += out_col / el_per_int;
scales += out_col / group_size;
biases += out_col / group_size;
scales += out_col_start / group_size;
biases += out_col_start / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
@@ -155,26 +156,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
int offset = simd_lid + i;
bool thread_in_bounds = offset < in_vec_size;
int offset_lid = simd_lid + i;
int offset_gid = simd_gid + i;
bool thread_in_bounds = offset_lid < in_vec_size;
bool group_in_bounds = offset_gid < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
}
if (simd_lid < groups_per_block && group_in_bounds) {
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -184,7 +181,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
@@ -543,6 +540,9 @@ instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -570,6 +570,9 @@ instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
@@ -601,6 +604,9 @@ instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -629,3 +635,6 @@ instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)
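
The instantiate_* blocks above add group size 32 to every quantized kernel family; the corresponding host_name strings follow the "<kernel>_<type>_gs_<group_size>_b_<bits>" pattern visible in the macros. A small sketch that lists the qvm names the new instantiations should produce; the float32/float16/bfloat16 type suffixes are an assumption, while the "_gs_"/"_b_" pattern is taken from the host_name macro above:

# Enumerate the qvm host names added by instantiate_qvm_types(32, ...).
# Type suffixes are assumed; "_gs_" and "_b_" come from the host_name macro.
for dtype in ("float32", "float16", "bfloat16"):
    for bits in (2, 4, 8):
        print(f"qvm_{dtype}_gs_32_b_{bits}")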

View File

@@ -2845,7 +2845,7 @@ std::tuple<array, array, array> quantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (group_size != 64 && group_size != 128) {
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";

View File

@@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
probs, targets, with_logits=False, reduction="none"
)
expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
print(losses_none, expected_none)
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'

View File

@@ -10,18 +10,19 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
for b in [2, 4, 8]:
w_q, scales, biases = mx.quantize(w, 64, b)
w_hat = mx.dequantize(w_q, scales, biases, 64, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
for gs in [32, 64, 128]:
for b in [2, 4, 8]:
w_q, scales, biases = mx.quantize(w, gs, b)
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64], # group_size
[128, 64, 32], # group_size
[2, 4, 8], # bits
[8, 32, 33, 64], # M
[512, 1024], # N
@@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64], # group_size
[128, 64, 32], # group_size
[2, 4, 8], # bits
[512, 1024], # M
[512, 1024], # N
@@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64], # group_size
[128, 64, 32], # group_size
[2, 4, 8], # bits
[512, 1024], # M
[512, 1024], # N
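
For reference, the bound asserted in test_quantize_dequantize above (errors <= scales / 2 + eps) is the usual rounding guarantee of an affine quantizer: each value moves by at most half a quantization step. A hedged, single-group sketch of that argument; the min/max affine scheme below is a stand-in for illustration, not necessarily the exact formula mx.quantize implements:

import mlx.core as mx

group_size, bits = 32, 4
w = mx.random.normal(shape=(group_size,))        # one quantization group

# Affine quantization of the group onto 2**bits levels.
w_min, w_max = w.min(), w.max()
scale = (w_max - w_min) / (2**bits - 1)
q = mx.round((w - w_min) / scale)                # integer codes in [0, 2**bits - 1]
w_hat = q * scale + w_min                        # dequantize

# Rounding moves each element by at most scale / 2.
print(((w - w_hat).abs() <= scale / 2 + 1e-6).all())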