From c7b0300af5c3c203aa3f6578086e3a9f7879f694 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 9 Jan 2025 11:45:57 -0800 Subject: [PATCH] Fix batched qmv bug (#1758) --- mlx/backend/metal/kernels/quantized.h | 24 ++++++++++++++++-------- python/tests/test_quantized.py | 11 ++++++----- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 33eec4910..b45b4bd96 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1323,13 +1323,14 @@ template uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { if (batched) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, - out_vec_size, + out_vec_size * M, x_batch_ndims, x_shape, x_strides, @@ -1374,13 +1375,14 @@ template uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, - out_vec_size, + out_vec_size * M, x_batch_ndims, x_shape, x_strides, @@ -1425,13 +1427,14 @@ template uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, - out_vec_size, + out_vec_size * M, x_batch_ndims, x_shape, x_strides, @@ -1476,13 +1479,14 @@ template uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, - out_vec_size, + out_vec_size * M, x_batch_ndims, x_shape, x_strides, @@ -1527,13 +1531,14 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, - out_vec_size, + out_vec_size * M, x_batch_ndims, x_shape, x_strides, @@ -1706,6 +1711,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, @@ -1714,7 +1720,7 @@ template lhs_indices, rhs_indices, y, - out_vec_size, + out_vec_size * M, batch_ndims, batch_shape, lhs_strides, @@ -1767,6 +1773,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, @@ -1775,7 +1782,7 @@ template lhs_indices, rhs_indices, y, - out_vec_size, + out_vec_size * M, batch_ndims, batch_shape, lhs_strides, @@ -1828,6 +1835,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, @@ -1836,7 +1844,7 @@ template lhs_indices, rhs_indices, y, - out_vec_size, + out_vec_size * M, batch_ndims, batch_shape, lhs_strides, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 7d4ba9949..6630338fc 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase): w_hat = mx.dequantize(w_q, scales, biases) # Test qmv - x = mx.random.normal(shape=(3, 1, 256)) - y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) - y_hat = x @ mx.swapaxes(w_hat, -1, -2) - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + for shape in [(3, 1, 256), (3, 4, 256)]: + x = mx.random.normal(shape=shape) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) # Test qmm_t x = mx.random.normal(shape=(3, 10, 256))