Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-19 23:51:14 +08:00)
Fix batched qmv bug (#1758)

commit c7b0300af5 (parent da8c885784)
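The hunks below touch the Metal quantized matrix-vector (qmv) kernels: when the input is batched with more than one row per batch element, the kernels advanced the output pointer by the wrong amount. A minimal sketch of the case this commit exercises, assuming mlx is installed; the shapes and the default quantization settings (group_size=64, bits=4) are illustrative, not taken from the commit:

    # Hypothetical reproduction of the batched qmv path fixed here:
    # a quantized matmul where each batch element has M > 1 rows.
    import mlx.core as mx

    w = mx.random.normal(shape=(64, 256))
    w_q, scales, biases = mx.quantize(w)     # defaults: group_size=64, bits=4
    w_hat = mx.dequantize(w_q, scales, biases)

    x = mx.random.normal(shape=(3, 4, 256))  # batch of 3, M = 4 rows each
    y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
    y_ref = x @ mx.swapaxes(w_hat, -1, -2)   # dequantized reference
    print((y_q - y_ref).abs().max())         # should be small (~1e-3) after the fix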
@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
     uint quad_gid [[quadgroup_index_in_threadgroup]],
     uint quad_lid [[thread_index_in_quadgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
+    int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets<T>(
         x,
         w,
         scales,
         biases,
         y,
-        out_vec_size,
+        out_vec_size * M,
         x_batch_ndims,
         x_shape,
         x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
       scales,
       biases,
       y,
-      out_vec_size,
+      out_vec_size * M,
       x_batch_ndims,
       x_shape,
       x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  int M = x_shape[x_batch_ndims];
   adjust_matrix_offsets<T>(
       x,
       w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
       lhs_indices,
       rhs_indices,
       y,
-      out_vec_size,
+      out_vec_size * M,
       batch_ndims,
       batch_shape,
       lhs_strides,
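Across all of the kernel hunks above the change is the same: the output batch stride passed to adjust_matrix_offsets becomes out_vec_size * M, where M = x_shape[x_batch_ndims] is the number of rows per batch element, instead of out_vec_size alone. A rough, hypothetical illustration of the difference (names and values are mine, not the kernel helper's):

    # With M rows per batch element, each element owns M * out_vec_size outputs.
    def y_offset_fixed(batch_idx, out_vec_size, M):
        return batch_idx * out_vec_size * M   # stride covers all M rows

    def y_offset_buggy(batch_idx, out_vec_size, M):
        return batch_idx * out_vec_size       # only advances by one row's worth

    # batch 1, out_vec_size=64, M=4 -> 256 (fixed) vs 64 (overlaps batch 0's rows)
    print(y_offset_fixed(1, 64, 4), y_offset_buggy(1, 64, 4))

The remaining hunk below updates the quantized-matmul test so the qmv check runs for both an M = 1 and an M = 4 shape.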
@@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, biases)

         # Test qmv
-        x = mx.random.normal(shape=(3, 1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ mx.swapaxes(w_hat, -1, -2)
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        for shape in [(3, 1, 256), (3, 4, 256)]:
+            x = mx.random.normal(shape=shape)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+            y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+            self.assertEqual(y_q.shape, y_hat.shape)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

         # Test qmm_t
         x = mx.random.normal(shape=(3, 10, 256))