Mirror of https://github.com/ml-explore/mlx.git — synced 2025-10-31 16:21:27 +08:00
Commit: Fix batched qmv bug (#1758)
		| @@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched> | ||||
|     uint quad_gid [[quadgroup_index_in_threadgroup]], | ||||
|     uint quad_lid [[thread_index_in_quadgroup]]) { | ||||
|   if (batched) { | ||||
|     int M = x_shape[x_batch_ndims]; | ||||
|     adjust_matrix_offsets<T>( | ||||
|         x, | ||||
|         w, | ||||
|         scales, | ||||
|         biases, | ||||
|         y, | ||||
|         out_vec_size, | ||||
|         out_vec_size * M, | ||||
|         x_batch_ndims, | ||||
|         x_shape, | ||||
|         x_strides, | ||||
| @@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched> | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   if (batched) { | ||||
|     int M = x_shape[x_batch_ndims]; | ||||
|     adjust_matrix_offsets<T>( | ||||
|         x, | ||||
|         w, | ||||
|         scales, | ||||
|         biases, | ||||
|         y, | ||||
|         out_vec_size, | ||||
|         out_vec_size * M, | ||||
|         x_batch_ndims, | ||||
|         x_shape, | ||||
|         x_strides, | ||||
| @@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched> | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   if (batched) { | ||||
|     int M = x_shape[x_batch_ndims]; | ||||
|     adjust_matrix_offsets<T>( | ||||
|         x, | ||||
|         w, | ||||
|         scales, | ||||
|         biases, | ||||
|         y, | ||||
|         out_vec_size, | ||||
|         out_vec_size * M, | ||||
|         x_batch_ndims, | ||||
|         x_shape, | ||||
|         x_strides, | ||||
| @@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched> | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   if (batched) { | ||||
|     int M = x_shape[x_batch_ndims]; | ||||
|     adjust_matrix_offsets<T>( | ||||
|         x, | ||||
|         w, | ||||
|         scales, | ||||
|         biases, | ||||
|         y, | ||||
|         out_vec_size, | ||||
|         out_vec_size * M, | ||||
|         x_batch_ndims, | ||||
|         x_shape, | ||||
|         x_strides, | ||||
| @@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32> | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   int M = x_shape[x_batch_ndims]; | ||||
|   adjust_matrix_offsets<T>( | ||||
|       x, | ||||
|       w, | ||||
|       scales, | ||||
|       biases, | ||||
|       y, | ||||
|       out_vec_size, | ||||
|       out_vec_size * M, | ||||
|       x_batch_ndims, | ||||
|       x_shape, | ||||
|       x_strides, | ||||
| @@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits> | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   int M = x_shape[x_batch_ndims]; | ||||
|   adjust_matrix_offsets<T>( | ||||
|       x, | ||||
|       w, | ||||
| @@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits> | ||||
|       lhs_indices, | ||||
|       rhs_indices, | ||||
|       y, | ||||
|       out_vec_size, | ||||
|       out_vec_size * M, | ||||
|       batch_ndims, | ||||
|       batch_shape, | ||||
|       lhs_strides, | ||||
| @@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits> | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   int M = x_shape[x_batch_ndims]; | ||||
|   adjust_matrix_offsets<T>( | ||||
|       x, | ||||
|       w, | ||||
| @@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits> | ||||
|       lhs_indices, | ||||
|       rhs_indices, | ||||
|       y, | ||||
|       out_vec_size, | ||||
|       out_vec_size * M, | ||||
|       batch_ndims, | ||||
|       batch_shape, | ||||
|       lhs_strides, | ||||
| @@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits> | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_gid [[simdgroup_index_in_threadgroup]], | ||||
|     uint simd_lid [[thread_index_in_simdgroup]]) { | ||||
|   int M = x_shape[x_batch_ndims]; | ||||
|   adjust_matrix_offsets<T>( | ||||
|       x, | ||||
|       w, | ||||
| @@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits> | ||||
|       lhs_indices, | ||||
|       rhs_indices, | ||||
|       y, | ||||
|       out_vec_size, | ||||
|       out_vec_size * M, | ||||
|       batch_ndims, | ||||
|       batch_shape, | ||||
|       lhs_strides, | ||||
|   | ||||
| @@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|                 w_hat = mx.dequantize(w_q, scales, biases) | ||||
|  | ||||
|                 # Test qmv | ||||
|                 x = mx.random.normal(shape=(3, 1, 256)) | ||||
|                 y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) | ||||
|                 y_hat = x @ mx.swapaxes(w_hat, -1, -2) | ||||
|                 self.assertEqual(y_q.shape, y_hat.shape) | ||||
|                 self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|                 for shape in [(3, 1, 256), (3, 4, 256)]: | ||||
|                     x = mx.random.normal(shape=shape) | ||||
|                     y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) | ||||
|                     y_hat = x @ mx.swapaxes(w_hat, -1, -2) | ||||
|                     self.assertEqual(y_q.shape, y_hat.shape) | ||||
|                     self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|  | ||||
|                 # Test qmm_t | ||||
|                 x = mx.random.normal(shape=(3, 10, 256)) | ||||
|   | ||||
Reference in New Issue
Block a user
Author: Alex Barron