Fix batched qmv bug (#1758)

Alex Barron 2025-01-09 11:45:57 -08:00 committed by GitHub
parent da8c885784
commit c7b0300af5
2 changed files with 22 additions and 13 deletions
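
The batched quantized matrix-vector (qmv) kernels passed `out_vec_size` to `adjust_matrix_offsets` as the per-batch output stride, which is only correct when each batch element contributes a single output row. With `M` rows per batch element, each element's output must start `out_vec_size * M` values past the previous one, so inputs such as shape `(3, 4, 256)` produced overlapping, corrupted outputs. The fix reads `M` from the input shape and scales the stride at every call site. A minimal reproduction sketch, mirroring the regression test added in this commit (the `(512, 256)` weight shape and the `1e-3` tolerance are assumptions carried over from the test's conventions):

```python
import mlx.core as mx

# A 512 x 256 weight matrix, quantized with the default settings.
w = mx.random.normal(shape=(512, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)

# Before this fix, M > 1 rows per batch element (e.g. shape (3, 4, 256))
# hit the batched qmv kernels with a per-batch output stride of
# out_vec_size instead of out_vec_size * M, corrupting the result.
x = mx.random.normal(shape=(3, 4, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
print((y_q - y_hat).abs().max())  # small (within ~1e-3) once fixed
```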


@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
     uint quad_gid [[quadgroup_index_in_threadgroup]],
     uint quad_lid [[thread_index_in_quadgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
       scales,
       biases,
       y,
-      out_vec_size,
+      out_vec_size * M,
       x_batch_ndims,
       x_shape,
       x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,

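All eleven call sites make the same one-line change: the stride argument to `adjust_matrix_offsets` becomes `out_vec_size * M`, with `M = x_shape[x_batch_ndims]` (the number of rows in each batch element). A minimal sketch of the offset arithmetic, assuming (this is not shown in the diff) that `adjust_matrix_offsets` advances the output pointer by `batch_element * output_stride`:

```python
def output_offset(batch_element: int, out_vec_size: int, M: int, fixed: bool) -> int:
    """Index where a batch element's output starts, in elements."""
    stride = out_vec_size * M if fixed else out_vec_size
    return batch_element * stride

out_vec_size, M = 512, 4  # 4 input rows per batch element
# Each batch element produces M rows of out_vec_size values,
# so element 1 must start at 512 * 4 = 2048 ...
assert output_offset(1, out_vec_size, M, fixed=True) == 2048
# ... but the old stride put it at 512, inside element 0's rows.
assert output_offset(1, out_vec_size, M, fixed=False) == 512
```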

@@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, biases)

         # Test qmv
-        x = mx.random.normal(shape=(3, 1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ mx.swapaxes(w_hat, -1, -2)
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        for shape in [(3, 1, 256), (3, 4, 256)]:
+            x = mx.random.normal(shape=shape)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+            y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+            self.assertEqual(y_q.shape, y_hat.shape)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

         # Test qmm_t
         x = mx.random.normal(shape=(3, 10, 256))
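
The updated test keeps the original `(3, 1, 256)` case (one row per batch element, the path that already worked) and adds `(3, 4, 256)` to cover the multi-row batched path that the stride fix addresses.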