mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32 * missing cpu dispatch * remove print * Fix qvm for group_size 32 --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
92c22c1ea3
commit
7a34e46677
@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
|
||||
|
||||
|
||||
quant_matmul = {
|
||||
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
|
||||
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
|
||||
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
|
||||
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
|
||||
@ -84,6 +87,15 @@ quant_matmul = {
|
||||
"quant_matmul_128_8": partial(
|
||||
_quant_matmul, transpose=False, group_size=128, bits=8
|
||||
),
|
||||
"quant_matmul_t_32_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=2
|
||||
),
|
||||
"quant_matmul_t_32_4": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=4
|
||||
),
|
||||
"quant_matmul_t_32_8": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=8
|
||||
),
|
||||
"quant_matmul_t_64_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=64, bits=2
|
||||
),
|
||||
|
@ -119,6 +119,12 @@ void _qmm_dispatch_typed(
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@ -135,6 +141,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@ -151,6 +163,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
|
@ -142,10 +142,11 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col = (tid.y * BN + simd_gid) * el_per_int;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@ -155,26 +156,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset = simd_lid + i;
|
||||
bool thread_in_bounds = offset < in_vec_size;
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -184,7 +181,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
@ -543,6 +540,9 @@ instantiate_qmv_types(128, 8)
|
||||
instantiate_qmv_types( 64, 2)
|
||||
instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
instantiate_qmv_types( 32, 2)
|
||||
instantiate_qmv_types( 32, 4)
|
||||
instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@ -570,6 +570,9 @@ instantiate_qvm_types(128, 8)
|
||||
instantiate_qvm_types( 64, 2)
|
||||
instantiate_qvm_types( 64, 4)
|
||||
instantiate_qvm_types( 64, 8)
|
||||
instantiate_qvm_types( 32, 2)
|
||||
instantiate_qvm_types( 32, 4)
|
||||
instantiate_qvm_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
||||
@ -601,6 +604,9 @@ instantiate_qmm_t_types(128, 8)
|
||||
instantiate_qmm_t_types( 64, 2)
|
||||
instantiate_qmm_t_types( 64, 4)
|
||||
instantiate_qmm_t_types( 64, 8)
|
||||
instantiate_qmm_t_types( 32, 2)
|
||||
instantiate_qmm_t_types( 32, 4)
|
||||
instantiate_qmm_t_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@ -629,3 +635,6 @@ instantiate_qmm_n_types(128, 8)
|
||||
instantiate_qmm_n_types( 64, 2)
|
||||
instantiate_qmm_n_types( 64, 4)
|
||||
instantiate_qmm_n_types( 64, 8)
|
||||
instantiate_qmm_n_types( 32, 2)
|
||||
instantiate_qmm_n_types( 32, 4)
|
||||
instantiate_qmm_n_types( 32, 8)
|
||||
|
@ -2845,7 +2845,7 @@ std::tuple<array, array, array> quantize(
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (group_size != 64 && group_size != 128) {
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 64 and 128.";
|
||||
|
@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
probs, targets, with_logits=False, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
|
||||
print(losses_none, expected_none)
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
|
@ -10,9 +10,10 @@ import mlx_tests
|
||||
class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_quantize_dequantize(self):
|
||||
w = mx.random.normal(shape=(128, 512))
|
||||
for gs in [32, 64, 128]:
|
||||
for b in [2, 4, 8]:
|
||||
w_q, scales, biases = mx.quantize(w, 64, b)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, 64, b)
|
||||
w_q, scales, biases = mx.quantize(w, gs, b)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||
eps = 1e-6
|
||||
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
|
||||
@ -21,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[8, 32, 33, 64], # M
|
||||
[512, 1024], # N
|
||||
@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
|
Loading…
Reference in New Issue
Block a user