mirror of https://github.com/ml-explore/mlx.git
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32
* missing cpu dispatch
* remove print
* Fix qvm for group_size 32

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
parent 92c22c1ea3
commit 7a34e46677
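For orientation, the change set below adds group_size=32 support end to end: the Python benchmark table, the CPU dispatch in _qmm_dispatch_typed, the Metal kernel instantiations, the quantize() validation in ops, and the Python tests. A minimal usage sketch of the new capability, assuming the mx.quantize / mx.dequantize signatures exercised by the updated tests in this diff:

    import mlx.core as mx

    # Quantize with the newly supported group size of 32 (4-bit here).
    w = mx.random.normal(shape=(128, 512))
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)

    # Round-trip back to floats; every group of 32 weights shares one
    # scale and bias, so the error is bounded by half a quantization step.
    w_hat = mx.dequantize(w_q, scales, biases, group_size=32, bits=4)
    errors = (w - w_hat).abs().reshape(*scales.shape, -1)
    assert (errors <= scales[..., None] / 2 + 1e-6).all()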
@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
 
 
 quant_matmul = {
+    "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
+    "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
+    "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
     "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
     "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
     "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -84,6 +87,15 @@ quant_matmul = {
     "quant_matmul_128_8": partial(
         _quant_matmul, transpose=False, group_size=128, bits=8
     ),
+    "quant_matmul_t_32_2": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=2
+    ),
+    "quant_matmul_t_32_4": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=4
+    ),
+    "quant_matmul_t_32_8": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=8
+    ),
     "quant_matmul_t_64_2": partial(
         _quant_matmul, transpose=True, group_size=64, bits=2
     ),
@@ -119,6 +119,12 @@ void _qmm_dispatch_typed(
   switch (bits) {
     case 2: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -135,6 +141,12 @@ void _qmm_dispatch_typed(
     }
     case 4: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -151,6 +163,12 @@ void _qmm_dispatch_typed(
     }
     case 8: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
@@ -142,10 +142,11 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   // Adjust positions
   const int out_vec_size_w = out_vec_size / el_per_int;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = (tid.y * BN + simd_gid) * el_per_int;
+  int out_col_start = tid.y * (BN * el_per_int);
+  int out_col = out_col_start + simd_gid * el_per_int;
   w += out_col / el_per_int;
-  scales += out_col / group_size;
-  biases += out_col / group_size;
+  scales += out_col_start / group_size;
+  biases += out_col_start / group_size;
   x += tid.z * in_vec_size;
   y += tid.z * out_vec_size + out_col;
 
@@ -155,26 +156,22 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
 
   // Loop over in_vec in blocks of colgroup
   for (int i=0; i<in_vec_size; i+=BM) {
-    int offset = simd_lid + i;
-    bool thread_in_bounds = offset < in_vec_size;
+    int offset_lid = simd_lid + i;
+    int offset_gid = simd_gid + i;
+    bool thread_in_bounds = offset_lid < in_vec_size;
+    bool group_in_bounds = offset_gid < in_vec_size;
 
     // Load the vec to shared memory
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (simd_gid == 0) {
-      x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
+      x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
     }
 
     // Load the scales and biases to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (simd_gid == 0) {
-      #pragma clang loop unroll(full)
-      for (int j=0; j<groups_per_block; j++) {
-        scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
-      }
-      #pragma clang loop unroll(full)
-      for (int j=0; j<groups_per_block; j++) {
-        biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
-      }
-    }
+    if (simd_lid < groups_per_block && group_in_bounds) {
+      scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
+      biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
+    }
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -184,7 +181,7 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
     bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
 
     // Load the matrix elements
-    w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
+    w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
 
     // Do all the work.
     #pragma clang loop unroll(full)
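A note on the qvm hunks above (a reading of the diff, not an authoritative account): the old kernel had simdgroup 0 alone loop over groups_per_block scale/bias entries per input row, with every thread's scales/biases pointers advanced per simdgroup (out_col / group_size). The new code rebases those pointers to the threadgroup's first output column (out_col_start / group_size) and loads cooperatively, one input row per simdgroup and one quantization group per lane, guarded by the new group_in_bounds check. In scalar pseudocode, assuming num_simdgroups stands in for the kernel's simdgroup count:

    # Scalar sketch of the new cooperative scale/bias load; names follow
    # the kernel, num_simdgroups is a stand-in for the launch configuration.
    for simd_gid in range(num_simdgroups):            # one input row per simdgroup
        offset_gid = i + simd_gid                     # row this simdgroup loads
        if offset_gid < in_vec_size:                  # group_in_bounds
            for simd_lid in range(groups_per_block):  # one group per lane
                dst = simd_gid * groups_per_block + simd_lid
                scales_block[dst] = scales[offset_gid * out_vec_size_g + simd_lid]
                biases_block[dst] = biases[offset_gid * out_vec_size_g + simd_lid]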
@@ -543,6 +540,9 @@ instantiate_qmv_types(128, 8)
 instantiate_qmv_types( 64, 2)
 instantiate_qmv_types( 64, 4)
 instantiate_qmv_types( 64, 8)
+instantiate_qmv_types( 32, 2)
+instantiate_qmv_types( 32, 4)
+instantiate_qmv_types( 32, 8)
 
 #define instantiate_qvm(name, itype, group_size, bits) \
   template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -570,6 +570,9 @@ instantiate_qvm_types(128, 8)
 instantiate_qvm_types( 64, 2)
 instantiate_qvm_types( 64, 4)
 instantiate_qvm_types( 64, 8)
+instantiate_qvm_types( 32, 2)
+instantiate_qvm_types( 32, 4)
+instantiate_qvm_types( 32, 8)
 
 #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
@@ -601,6 +604,9 @@ instantiate_qmm_t_types(128, 8)
 instantiate_qmm_t_types( 64, 2)
 instantiate_qmm_t_types( 64, 4)
 instantiate_qmm_t_types( 64, 8)
+instantiate_qmm_t_types( 32, 2)
+instantiate_qmm_t_types( 32, 4)
+instantiate_qmm_t_types( 32, 8)
 
 #define instantiate_qmm_n(name, itype, group_size, bits) \
   template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -629,3 +635,6 @@ instantiate_qmm_n_types(128, 8)
 instantiate_qmm_n_types( 64, 2)
 instantiate_qmm_n_types( 64, 4)
 instantiate_qmm_n_types( 64, 8)
+instantiate_qmm_n_types( 32, 2)
+instantiate_qmm_n_types( 32, 4)
+instantiate_qmm_n_types( 32, 8)
@@ -2845,7 +2845,7 @@ std::tuple<array, array, array> quantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  if (group_size != 64 && group_size != 128) {
+  if (group_size != 32 && group_size != 64 && group_size != 128) {
     std::ostringstream msg;
     msg << "[quantize] The requested group size " << group_size
         << " is not supported. The supported group sizes are 64 and 128.";
@@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
             probs, targets, with_logits=False, reduction="none"
         )
         expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
-        print(losses_none, expected_none)
         self.assertTrue(mx.allclose(losses_none, expected_none))
 
         # Test with reduction 'mean'
@@ -10,18 +10,19 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
-        for b in [2, 4, 8]:
-            w_q, scales, biases = mx.quantize(w, 64, b)
-            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
-            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-            eps = 1e-6
-            self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+        for gs in [32, 64, 128]:
+            for b in [2, 4, 8]:
+                w_q, scales, biases = mx.quantize(w, gs, b)
+                w_hat = mx.dequantize(w_q, scales, biases, gs, b)
+                errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+                eps = 1e-6
+                self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
             [512, 1024],  # N
@@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
@@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
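Taken together, the tests now cover group_size 32 for quantize/dequantize, qmm, and the vector kernels. A hedged sketch of the matmul path itself, following the transpose=True convention of the "quant_matmul_t_32_*" benchmark entries above (mx.quantized_matmul is the public op; exact defaults may differ):

    import mlx.core as mx

    # w is stored as (N, K) and transposed inside the op, matching the
    # transpose=True benchmarks; groups of 32 run along the K dimension.
    x = mx.random.normal(shape=(8, 512))
    w = mx.random.normal(shape=(1024, 512))

    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=32, bits=4
    )
    print(y.shape)  # (8, 1024)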