mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Fix batched qmv bug (#1758)
This commit is contained in:
parent
da8c885784
commit
c7b0300af5
@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
|||||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||||
if (batched) {
|
if (batched) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
scales,
|
scales,
|
||||||
biases,
|
biases,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
x_batch_ndims,
|
x_batch_ndims,
|
||||||
x_shape,
|
x_shape,
|
||||||
x_strides,
|
x_strides,
|
||||||
@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
if (batched) {
|
if (batched) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
scales,
|
scales,
|
||||||
biases,
|
biases,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
x_batch_ndims,
|
x_batch_ndims,
|
||||||
x_shape,
|
x_shape,
|
||||||
x_strides,
|
x_strides,
|
||||||
@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
if (batched) {
|
if (batched) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
scales,
|
scales,
|
||||||
biases,
|
biases,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
x_batch_ndims,
|
x_batch_ndims,
|
||||||
x_shape,
|
x_shape,
|
||||||
x_strides,
|
x_strides,
|
||||||
@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
if (batched) {
|
if (batched) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
scales,
|
scales,
|
||||||
biases,
|
biases,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
x_batch_ndims,
|
x_batch_ndims,
|
||||||
x_shape,
|
x_shape,
|
||||||
x_strides,
|
x_strides,
|
||||||
@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
scales,
|
scales,
|
||||||
biases,
|
biases,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
x_batch_ndims,
|
x_batch_ndims,
|
||||||
x_shape,
|
x_shape,
|
||||||
x_strides,
|
x_strides,
|
||||||
@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
|
|||||||
lhs_indices,
|
lhs_indices,
|
||||||
rhs_indices,
|
rhs_indices,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
batch_ndims,
|
batch_ndims,
|
||||||
batch_shape,
|
batch_shape,
|
||||||
lhs_strides,
|
lhs_strides,
|
||||||
@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
|
|||||||
lhs_indices,
|
lhs_indices,
|
||||||
rhs_indices,
|
rhs_indices,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
batch_ndims,
|
batch_ndims,
|
||||||
batch_shape,
|
batch_shape,
|
||||||
lhs_strides,
|
lhs_strides,
|
||||||
@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
int M = x_shape[x_batch_ndims];
|
||||||
adjust_matrix_offsets<T>(
|
adjust_matrix_offsets<T>(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
|
|||||||
lhs_indices,
|
lhs_indices,
|
||||||
rhs_indices,
|
rhs_indices,
|
||||||
y,
|
y,
|
||||||
out_vec_size,
|
out_vec_size * M,
|
||||||
batch_ndims,
|
batch_ndims,
|
||||||
batch_shape,
|
batch_shape,
|
||||||
lhs_strides,
|
lhs_strides,
|
||||||
|
@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
w_hat = mx.dequantize(w_q, scales, biases)
|
w_hat = mx.dequantize(w_q, scales, biases)
|
||||||
|
|
||||||
# Test qmv
|
# Test qmv
|
||||||
x = mx.random.normal(shape=(3, 1, 256))
|
for shape in [(3, 1, 256), (3, 4, 256)]:
|
||||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
x = mx.random.normal(shape=shape)
|
||||||
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
self.assertEqual(y_q.shape, y_hat.shape)
|
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
|
||||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
# Test qmm_t
|
# Test qmm_t
|
||||||
x = mx.random.normal(shape=(3, 10, 256))
|
x = mx.random.normal(shape=(3, 10, 256))
|
||||||
|
Loading…
Reference in New Issue
Block a user