diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp index e9fec1303..3c1312fbc 100644 --- a/mlx/backend/accelerate/quantized.cpp +++ b/mlx/backend/accelerate/quantized.cpp @@ -18,49 +18,61 @@ void _qmm_t_4_64( const float* biases, int M, int N, - int K) { + int K, + int B, + bool batched_w) { constexpr int bits = 4; constexpr int group_size = 64; constexpr int bitmask = (1 << bits) - 1; constexpr int pack_factor = 32 / bits; constexpr int packs_in_group = group_size / pack_factor; - for (int m = 0; m < M; m++) { - const uint32_t* w_local = w; - const float* scales_local = scales; - const float* biases_local = biases; + int w_els = N * K / pack_factor; + int g_els = w_els * pack_factor / group_size; - for (int n = 0; n < N; n++) { - const simd_float16* x_local = (simd_float16*)x; - simd_float16 sum = 0; - for (int k = 0; k < K; k += group_size) { - float scale = *scales_local++; - float bias = *biases_local++; + for (int i = 0; i < B; i++) { + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const float* scales_local = scales; + const float* biases_local = biases; - for (int kw = 0; kw < packs_in_group; kw += 2) { - // TODO: vectorize this properly - simd_uint16 wi; - for (int e = 0; e < 2; e++) { - uint32_t wii = *w_local++; - for (int p = 0; p < 8; p++) { - wi[e * 8 + p] = wii & bitmask; - wii >>= bits; + for (int n = 0; n < N; n++) { + const simd_float16* x_local = (simd_float16*)x; + simd_float16 sum = 0; + for (int k = 0; k < K; k += group_size) { + float scale = *scales_local++; + float bias = *biases_local++; + + for (int kw = 0; kw < packs_in_group; kw += 2) { + // TODO: vectorize this properly + simd_uint16 wi; + for (int e = 0; e < 2; e++) { + uint32_t wii = *w_local++; + for (int p = 0; p < 8; p++) { + wi[e * 8 + p] = wii & bitmask; + wii >>= bits; + } } - } - simd_float16 wf = simd_float(wi); - wf *= scale; - wf += bias; + simd_float16 wf = simd_float(wi); + wf *= scale; + wf += bias; - sum += (*x_local) * wf; - x_local++; + sum += (*x_local) * wf; + x_local++; + } } + + *result = simd_reduce_add(sum); + result++; } - *result = simd_reduce_add(sum); - result++; + x += K; + } + if (batched_w) { + w += w_els; + scales += g_els; + biases += g_els; } - - x += K; } } @@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { if (condition) { out.set_data(allocator::malloc_or_wait(out.nbytes())); int K = x.shape(-1); - int M = x.size() / K; + int M = x.shape(-2); int N = out.shape(-1); + int B = x.size() / K / M; + bool batched_w = w.ndim() > 2; _qmm_t_4_64( out.data(), x.data(), @@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { biases.data(), M, N, - K); + K, + B, + batched_w); } else { eval(inputs, out); } diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp index 7f37e7bc2..daeb50c6a 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/common/quantized.cpp @@ -201,55 +201,61 @@ void _qmm_dispatch( int group_size, bool transposed_w) { int K = x.shape(-1); - int M = x.size() / K; + int M = x.shape(-2); int N = out.shape(-1); - switch (x.dtype()) { - case float32: - _qmm_dispatch_typed( - out.data(), - x.data(), - w.data(), - scales.data(), - biases.data(), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - case float16: - _qmm_dispatch_typed( - out.data(), - x.data(), - w.data(), - scales.data(), - biases.data(), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - case bfloat16: - 
_qmm_dispatch_typed( - out.data(), - x.data(), - w.data(), - scales.data(), - biases.data(), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - default: - throw std::invalid_argument( - "[quantized_matmul] only floating types are supported"); + int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; + int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; + + int batch_size = x.size() / x.shape(-1) / x.shape(-2); + for (int i = 0; i < batch_size; i++) { + switch (x.dtype()) { + case float32: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(i * M * K, x), + w.data() + elem_to_loc(i * w_els, w), + scales.data() + elem_to_loc(i * g_els, scales), + biases.data() + elem_to_loc(i * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + case float16: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(i * M * K, x), + w.data() + elem_to_loc(i * w_els, w), + scales.data() + elem_to_loc(i * g_els, scales), + biases.data() + elem_to_loc(i * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + case bfloat16: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(i * M * K, x), + w.data() + elem_to_loc(i * w_els, w), + scales.data() + elem_to_loc(i * g_els, scales), + biases.data() + elem_to_loc(i * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } } } diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 4f388b9f3..e8f1c18a2 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -8,6 +8,7 @@ using namespace metal; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; template inline U load_vector(const device T* x, thread U* x_thread) { @@ -371,6 +372,64 @@ struct QuantizedBlockLoader { } }; +template +METAL_FUNC void qmv_quad_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.y * in_vec_size + quad_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w 
* quads_per_simd); + const device T* sl = scales + row * in_vec_size_g * quads_per_simd; + const device T* bl = biases + row * in_vec_size_g * quads_per_simd; + + U s = sl[0]; + U b = bl[0]; + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + template METAL_FUNC void qmv_fast_impl( const device uint32_t* w, @@ -586,10 +645,10 @@ METAL_FUNC void qmv_impl( template METAL_FUNC void qvm_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, @@ -697,16 +756,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_t_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -818,16 +877,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_n_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -942,6 +1001,45 @@ METAL_FUNC void qmm_n_impl( } } +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + template METAL_FUNC void adjust_matrix_offsets( const device T*& x, @@ -996,7 +1094,58 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template +[[kernel]] void qmv_quad( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + 
const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_quad_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid); +} + +template [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1005,9 +1154,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_fast_impl( w, scales, @@ -1021,7 +1196,7 @@ template simd_lid); } -template +template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1030,9 +1205,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_impl( w, scales, @@ -1046,23 +1247,49 @@ template simd_lid); } -template +template [[kernel]] void qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& 
x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -1076,18 +1303,27 @@ template < const int group_size, const int bits, const bool aligned_N, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1099,26 +1335,53 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* 
s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1131,8 +1394,27 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template @@ -1141,23 +1423,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1202,23 +1484,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* 
s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1259,27 +1541,27 @@ template template [[kernel]] void bs_qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1306,10 +1588,10 @@ template b_strides, tid); qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -1327,28 +1609,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const 
device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1383,7 +1665,7 @@ template < b_strides, tid); qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < @@ -1394,28 +1676,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant 
size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1451,7 +1733,7 @@ template < b_strides, tid); qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 130cdda22..4720a3cda 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -5,67 +5,104 @@ #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/quantized.h" -#define instantiate_quantized(name, type, group_size, bits) \ - instantiate_kernel( \ - #name "_" #type "_gs_" #group_size "_b_" #bits, \ - name, \ - type, \ - group_size, \ +#define instantiate_quantized(name, type, group_size, bits) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits, \ + name, \ + type, \ + group_size, \ bits) -#define instantiate_quantized_types(name, group_size, bits) \ - instantiate_quantized(name, float, group_size, bits) \ - instantiate_quantized(name, float16_t, group_size, bits) \ - instantiate_quantized(name, bfloat16_t, group_size, bits) +#define instantiate_quantized_batched(name, type, group_size, bits, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + batched) -#define instantiate_quantized_groups(name, bits) \ - instantiate_quantized_types(name, 128, bits) \ - instantiate_quantized_types(name, 64, bits) \ - instantiate_quantized_types(name, 32, bits) - -#define instantiate_quantized_all(name) \ - instantiate_quantized_groups(name, 2) \ - instantiate_quantized_groups(name, 4) \ - instantiate_quantized_groups(name, 8) - -instantiate_quantized_all(qmv_fast) -instantiate_quantized_all(qmv) -instantiate_quantized_all(qvm) -instantiate_quantized_all(qmm_n) -instantiate_quantized_all(bs_qmv_fast) -instantiate_quantized_all(bs_qmv) -instantiate_quantized_all(bs_qvm) -instantiate_quantized_all(bs_qmm_n) -instantiate_quantized_all(affine_quantize) -instantiate_quantized_all(affine_quantize_scales_biases) -instantiate_quantized_all(affine_dequantize) - -#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \ - instantiate_kernel( \ - #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \ +#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \ name, \ type, \ group_size, \ bits, \ aligned) -#define 
instantiate_quantized_types_aligned(name, group_size, bits) \ - instantiate_quantized_aligned(name, float, group_size, bits, true) \ - instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \ - instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \ - instantiate_quantized_aligned(name, float, group_size, bits, false) \ - instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \ - instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false) +#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned, \ + batched) -#define instantiate_quantized_groups_aligned(name, bits) \ - instantiate_quantized_types_aligned(name, 128, bits) \ - instantiate_quantized_types_aligned(name, 64, bits) \ - instantiate_quantized_types_aligned(name, 32, bits) +#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + D, \ + batched) -#define instantiate_quantized_all_aligned(name) \ - instantiate_quantized_groups_aligned(name, 2) \ - instantiate_quantized_groups_aligned(name, 4) \ - instantiate_quantized_groups_aligned(name, 8) \ +#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ + instantiate_quantized_batched(name, type, group_size, bits, 1) \ + instantiate_quantized_batched(name, type, group_size, bits, 0) -instantiate_quantized_all_aligned(qmm_t) -instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on +#define instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \ + instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \ + instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \ + instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits) + +#define instantiate_quantized_all_single(type, group_size, bits) \ + instantiate_quantized(affine_quantize, type, group_size, bits) \ + instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \ + instantiate_quantized(affine_dequantize, type, group_size, bits) \ + instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ + instantiate_quantized(bs_qmv, type, group_size, bits) \ + instantiate_quantized(bs_qvm, type, group_size, bits) \ + instantiate_quantized(bs_qmm_n, type, group_size, bits) + +#define instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ + instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ + instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ + instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0) + +#define instantiate_quantized_all_quad(type, group_size, bits) \ + instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \ + instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \ + instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \ + instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0) + 
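For readers following the macro rework above: each instantiation expands to a kernel whose name encodes its template arguments, and the host side in mlx/backend/metal/quantized.cpp rebuilds the same string when it looks the kernel up. A small illustrative sketch of the naming scheme, in plain Python; the function name here is made up for this note and is not part of MLX:

    def quantized_kernel_name(base, dtype, group_size, bits,
                              D=None, aligned=None, batched=None):
        # Mirrors the suffixes added by the instantiate_* macros above and by
        # the host-side name building: "_d_" for the quad kernels, "_alN_" for
        # the qmm_t variants, and "_batch_" for the non-gather kernels.
        name = f"{base}_{dtype}_gs_{group_size}_b_{bits}"
        if D is not None:
            name += f"_d_{D}"
        if aligned is not None:
            name += f"_alN_{'true' if aligned else 'false'}"
        if batched is not None:
            name += f"_batch_{1 if batched else 0}"
        return name

    # e.g. quantized_kernel_name("qmv_fast", "float16_t", 64, 4, batched=True)
    # yields "qmv_fast_float16_t_gs_64_b_4_batch_1", matching the expansion of
    # instantiate_quantized_batched(qmv_fast, float16_t, 64, 4, 1) above.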
+#define instantiate_quantized_funcs(type, group_size, bits) \ + instantiate_quantized_all_single(type, group_size, bits) \ + instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_all_quad(type, group_size, bits) + +#define instantiate_quantized_types(group_size, bits) \ + instantiate_quantized_funcs(float, group_size, bits) \ + instantiate_quantized_funcs(float16_t, group_size, bits) \ + instantiate_quantized_funcs(bfloat16_t, group_size, bits) + +#define instantiate_quantized_groups(bits) \ + instantiate_quantized_types(128, bits) \ + instantiate_quantized_types(64, bits) \ + instantiate_quantized_types(32, bits) + +#define instantiate_quantized_all() \ + instantiate_quantized_groups(2) \ + instantiate_quantized_groups(4) \ + instantiate_quantized_groups(8) + +instantiate_quantized_all() // clang-format on diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index de70c7562..d1c594cb4 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -12,231 +12,29 @@ namespace mlx::core { -void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 4); - - out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto& s = stream(); - auto& d = metal::device(s.device); - +void launch_qmm( + std::string name, + const std::vector& inputs, + array& out, + int group_size, + int bits, + int D, + int O, + int B, + int N, + MTL::Size& group_dims, + MTL::Size& grid_dims, + bool batched, + bool matrix, + bool gather, + bool aligned, + bool quad, + const Stream& s) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; auto& biases_pre = inputs[3]; - std::vector copies; - auto ensure_row_contiguous = [&copies, &s](const array& arr) { - if (arr.flags().row_contiguous) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous(x_pre); - auto w = ensure_row_contiguous(w_pre); - auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); - - int D = x.shape(-1); - int B = x.size() / D; - int O = out.shape(-1); - if (transpose_) { - // Route to the fast qmv kernel that has no bounds checking - if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "qmv_fast", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - - int bo = 8; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(O / bo, B, 1); - - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - // Route to the qmv kernel - else if (B < 6) { - std::ostringstream kname; - 
auto type_string = get_type_string(x.dtype()); - kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "qmv", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - - int bo = 8; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1); - - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - // Route to the qmm_t kernel - else { - std::ostringstream kname; - std::string aligned_n = (O % 32) == 0 ? "true" : "false"; - auto type_string = get_type_string(x.dtype()); - kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_ << "_alN_" << aligned_n; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - int bk = 32; - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1); - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&B, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - compute_encoder->setBytes(&D, sizeof(int), 7); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - } else { - // Route to the qvm kernel - if (B < 4) { - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "qvm", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - - int bo = 64; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(O / bo, B, 1); - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - // Route to the qmm_n kernel - else { - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto 
template_def = get_template_definition( - kname.str(), "qmm_n", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - int bk = 32; - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); - - if ((O % bn) != 0) { - std::ostringstream msg; - msg << "[quantized_matmul] The output size should be divisible by " - << bn << " but received " << O << "."; - throw std::runtime_error(msg.str()); - } - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&B, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - compute_encoder->setBytes(&D, sizeof(int), 7); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - } - - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } -} - -void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 6); - - out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto& s = stream(); - auto& d = metal::device(s.device); - - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; - - // TODO: collapse batch dims - auto& batch_shape = lhs_indices.shape(); - int batch_ndims = batch_shape.size(); - auto& lhs_strides = lhs_indices.strides(); - auto& rhs_strides = rhs_indices.strides(); - // Ensure that the last two dims are row contiguous. // TODO: Check if we really need this for x as well... std::vector copies; @@ -266,256 +64,205 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { auto& s_strides = scales.strides(); auto& b_strides = biases.strides(); + std::string aligned_n = (O % 32) == 0 ? 
"true" : "false"; + + std::ostringstream kname; + auto type_string = get_type_string(x.dtype()); + kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; + if (quad) { + kname << "_d_" << D; + } + if (aligned) { + kname << "_alN_" << aligned_n; + } + if (!gather) { + kname << "_batch_" << batched; + } + + // Encode and dispatch kernel + std::string template_def; + if (quad) { + template_def = get_template_definition( + kname.str(), name, type_string, group_size, bits, D, batched); + } else if (aligned && !gather) { + template_def = get_template_definition( + kname.str(), name, type_string, group_size, bits, aligned_n, batched); + } else if (!gather && !aligned) { + template_def = get_template_definition( + kname.str(), name, type_string, group_size, bits, batched); + } else if (aligned && gather) { + template_def = get_template_definition( + kname.str(), name, type_string, group_size, bits, aligned_n); + } else { + template_def = get_template_definition( + kname.str(), name, type_string, group_size, bits); + } + auto& d = metal::device(s.device); + auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder->setBytes(&D, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + + int offset = 7; + if (matrix) { + compute_encoder->setBytes(&B, sizeof(int), 7); + offset += 1; + } + + if (batched || gather) { + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset); + set_vector_bytes(compute_encoder, x_shape, offset + 1); + set_vector_bytes(compute_encoder, x_strides, offset + 2); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3); + set_vector_bytes(compute_encoder, w_shape, offset + 4); + set_vector_bytes(compute_encoder, w_strides, offset + 5); + set_vector_bytes(compute_encoder, s_strides, offset + 6); + set_vector_bytes(compute_encoder, b_strides, offset + 7); + } + if (gather) { + auto& lhs_indices = inputs[4]; + auto& rhs_indices = inputs[5]; + + // TODO: collapse batch dims + auto& batch_shape = lhs_indices.shape(); + int batch_ndims = batch_shape.size(); + auto& lhs_strides = lhs_indices.strides(); + auto& rhs_strides = rhs_indices.strides(); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8); + set_vector_bytes(compute_encoder, batch_shape, offset + 9); + compute_encoder.set_input_array(lhs_indices, offset + 10); + compute_encoder.set_input_array(rhs_indices, offset + 11); + set_vector_bytes(compute_encoder, lhs_strides, offset + 12); + set_vector_bytes(compute_encoder, rhs_strides, offset + 13); + } + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); +} + +void qmm_op( + const std::vector& inputs, + array& out, + bool transpose, + int group_size, + int bits, + bool gather, + const Stream& s) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + MTL::Size group_dims; + MTL::Size grid_dims; + + auto& x = inputs[0]; + auto& w = inputs[1]; + bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous); + int D = x.shape(-1); - int B = x.shape(-2); int O = out.shape(-1); - int N = out.size() / B / O; - if (transpose_) { - // Route to the fast bs_qmv kernel that has no bounds checking - if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 
512) { - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; + // For the unbatched W case, avoid `adjust_matrix_offsets` + // for a small performance gain. + int B = (batched || gather) ? x.shape(-2) : x.size() / D; + int N = (batched || gather) ? out.size() / B / O : 1; - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "bs_qmv_fast", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); + std::string name = gather ? "bs_" : ""; + bool matrix = false; + bool aligned = false; + bool quad = false; + if (transpose) { + if (B < 6 && (D == 128 || D == 64)) { + name += "qmv_quad"; + constexpr int quads_per_simd = 8; + constexpr int results_per_quadgroup = 8; + int bo = quads_per_simd * results_per_quadgroup; + int simdgroup_size = 32; + group_dims = MTL::Size(simdgroup_size, 1, 1); + grid_dims = MTL::Size((O + bo - 1) / bo, B, N); + quad = true; + } else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { + name += "qmv_fast"; int bo = 8; int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(O / bo, B, N); - - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder->setBytes(&D, sizeof(int), 7); - compute_encoder->setBytes(&O, sizeof(int), 8); - - compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, lhs_strides, 11); - set_vector_bytes(compute_encoder, rhs_strides, 12); - - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); - set_vector_bytes(compute_encoder, x_shape, 14); - set_vector_bytes(compute_encoder, x_strides, 15); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); - set_vector_bytes(compute_encoder, w_shape, 17); - set_vector_bytes(compute_encoder, w_strides, 18); - set_vector_bytes(compute_encoder, s_strides, 19); - set_vector_bytes(compute_encoder, b_strides, 20); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - else if (B < 6) { - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "bs_qmv", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - + group_dims = MTL::Size(bd, 2, 1); + grid_dims = MTL::Size(O / bo, B, N); + } else if (B < 6) { + name += "qmv"; int bo = 8; int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N); - - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - 
compute_encoder.set_output_array(out, 6); - compute_encoder->setBytes(&D, sizeof(int), 7); - compute_encoder->setBytes(&O, sizeof(int), 8); - - compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, lhs_strides, 11); - set_vector_bytes(compute_encoder, rhs_strides, 12); - - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); - set_vector_bytes(compute_encoder, x_shape, 14); - set_vector_bytes(compute_encoder, x_strides, 15); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); - set_vector_bytes(compute_encoder, w_shape, 17); - set_vector_bytes(compute_encoder, w_strides, 18); - set_vector_bytes(compute_encoder, s_strides, 19); - set_vector_bytes(compute_encoder, b_strides, 20); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - // Route to the bs_qmm_t - else { - std::ostringstream kname; - std::string aligned_n = (O % 32) == 0 ? "true" : "false"; - auto type_string = get_type_string(out.dtype()); - kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_ << "_alN_" << aligned_n; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - + group_dims = MTL::Size(bd, 2, 1); + grid_dims = MTL::Size((O + bo - 1) / bo, B, N); + } else { int wn = 2; int wm = 2; int bm = 32; int bn = 32; - int bk = 32; - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder->setBytes(&B, sizeof(int), 7); - compute_encoder->setBytes(&O, sizeof(int), 8); - compute_encoder->setBytes(&D, sizeof(int), 9); - - compute_encoder->setBytes(&batch_ndims, sizeof(int), 10); - set_vector_bytes(compute_encoder, batch_shape, 11); - set_vector_bytes(compute_encoder, lhs_strides, 12); - set_vector_bytes(compute_encoder, rhs_strides, 13); - - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14); - set_vector_bytes(compute_encoder, x_shape, 15); - set_vector_bytes(compute_encoder, x_strides, 16); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17); - set_vector_bytes(compute_encoder, w_shape, 18); - set_vector_bytes(compute_encoder, w_strides, 19); - set_vector_bytes(compute_encoder, s_strides, 20); - set_vector_bytes(compute_encoder, b_strides, 21); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + group_dims = MTL::Size(32, wn, wm); + grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); + name += "qmm_t"; + matrix = true; + aligned = true; } } else { - // Route to the bs_qvm kernel if (B < 4) { - std::ostringstream kname; - auto type_string = get_type_string(out.dtype()); - kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "bs_qvm", type_string, group_size_, bits_); - auto kernel = 
get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - + name += "qvm"; int bo = 64; int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(O / bo, B, N); - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder->setBytes(&D, sizeof(int), 7); - compute_encoder->setBytes(&O, sizeof(int), 8); - - compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, lhs_strides, 11); - set_vector_bytes(compute_encoder, rhs_strides, 12); - - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); - set_vector_bytes(compute_encoder, x_shape, 14); - set_vector_bytes(compute_encoder, x_strides, 15); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); - set_vector_bytes(compute_encoder, w_shape, 17); - set_vector_bytes(compute_encoder, w_strides, 18); - set_vector_bytes(compute_encoder, s_strides, 19); - set_vector_bytes(compute_encoder, b_strides, 20); - - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - } - - // Route to bs_qmm_n - else { - std::ostringstream kname; - auto type_string = get_type_string(out.dtype()); - kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto template_def = get_template_definition( - kname.str(), "bs_qmm_n", type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); - + group_dims = MTL::Size(bd, 2, 1); + grid_dims = MTL::Size(O / bo, B, N); + } else { + name += "qmm_n"; int wn = 2; int wm = 2; int bm = 32; int bn = 32; - int bk = 32; - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); - + group_dims = MTL::Size(32, wn, wm); + grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); + matrix = true; if ((O % bn) != 0) { std::ostringstream msg; msg << "[quantized_matmul] The output size should be divisible by " << bn << " but received " << O << "."; throw std::runtime_error(msg.str()); } - - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder->setBytes(&B, sizeof(int), 7); - compute_encoder->setBytes(&O, sizeof(int), 8); - compute_encoder->setBytes(&D, sizeof(int), 9); - - compute_encoder->setBytes(&batch_ndims, sizeof(int), 10); - set_vector_bytes(compute_encoder, batch_shape, 11); - set_vector_bytes(compute_encoder, lhs_strides, 12); - set_vector_bytes(compute_encoder, rhs_strides, 13); - - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14); - set_vector_bytes(compute_encoder, x_shape, 15); - set_vector_bytes(compute_encoder, x_strides, 16); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17); - set_vector_bytes(compute_encoder, w_shape, 18); - set_vector_bytes(compute_encoder, w_strides, 19); - set_vector_bytes(compute_encoder, 
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index d203b5bde..731a10bad 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -725,15 +725,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
 
   int el_per_int = 32 / bits;
 
-  if (w.shape(-1) < 32 * el_per_int) {
-    std::ostringstream msg;
-    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
-        << "too small for quantization. We support >=512 for 2 bits, "
-        << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
-        << "shape " << w.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
   auto fallback = [group_size, bits, el_per_int, s](
                       const std::vector<array>& inputs) -> std::vector<array> {
     auto& w = inputs[0];
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 680240439..f24aab52a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3592,10 +3592,10 @@ array conv_general(
 }
 
 array quantized_matmul(
-    const array& x,
-    const array& w,
-    const array& scales,
-    const array& biases,
+    array x,
+    array w,
+    array scales,
+    array biases,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
@@ -3604,11 +3604,27 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  if (w.ndim() != 2) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
-        << "received w with shape " << w.shape();
-    throw std::invalid_argument(msg.str());
+  // QuantizedMatmul handles w.ndim == 2 case.
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
+    std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
+    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
+
+    // Broadcast x
+    inner_shape.push_back(x.shape(-2));
+    inner_shape.push_back(x.shape(-1));
+    x = broadcast_to(x, inner_shape, s);
+
+    // Broadcast w
+    *(inner_shape.end() - 2) = w.shape(-2);
+    *(inner_shape.end() - 1) = w.shape(-1);
+    w = broadcast_to(w, inner_shape, s);
+
+    *(inner_shape.end() - 1) = scales.shape(-1);
+    scales = broadcast_to(scales, inner_shape, s);
+
+    *(inner_shape.end() - 1) = biases.shape(-1);
+    biases = broadcast_to(biases, inner_shape, s);
   }
 
   auto dtype = result_type(x, scales, biases);
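To make the broadcast rule above concrete: the batch dimensions of x are broadcast against those of w, scales, and biases before dispatch, so the two operands may carry different (but compatible) leading batch shapes. A minimal sketch with illustrative shapes, not taken from the patch:

# Sketch of the batch broadcasting added above (illustrative shapes).
import mlx.core as mx

w = mx.random.normal(shape=(3, 128, 512))       # per-batch (B, N, K) weights
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

x = mx.random.normal(shape=(2, 1, 10, 512))     # extra leading batch dim on x
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)                                  # expected (2, 3, 10, 128)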
diff --git a/mlx/ops.h b/mlx/ops.h
index 7b8e9327e..c2ea9438c 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1287,10 +1287,10 @@ array conv_transpose3d(
 
 /** Quantized matmul multiplies x with a quantized matrix w*/
 array quantized_matmul(
-    const array& x,
-    const array& w,
-    const array& scales,
-    const array& biases,
+    array x,
+    array w,
+    array scales,
+    array biases,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
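The tests that follow also keep a B == 0 case in which a single 2-D quantized weight is shared across a batch of activations. A minimal sketch of that usage, with illustrative shapes:

# Sketch: one 2-D quantized weight shared across batched activations,
# mirroring the B == 0 branch in the tests below (illustrative shapes).
import mlx.core as mx

w = mx.random.normal(shape=(256, 512))          # unbatched (N, K) weight
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

x = mx.random.normal(shape=(8, 1, 512))         # batched activations
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)                                  # (8, 1, 256)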
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 47c924c59..2b5251847 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -117,19 +117,24 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
-            [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # M
+            [64, 128, 512, 1024],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(M, N), key=k2)
+        for group_size, bits, M, N, B in tests:
+            if group_size > N:
+                continue
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
+                w_shape = (M, N) if B == 0 else (B, M, N)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
                     x, w_q, scales, biases, True, group_size, bits
                 )
-                y_hat = x @ w_hat.T
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)
@@ -140,12 +145,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(N, M), key=k2)
+        for group_size, bits, M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
@@ -172,37 +180,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.eval(y)
 
     def test_small_matrix(self):
-        w = mx.random.normal(shape=(8, 256))
-        w_q, scales, biases = mx.quantize(w)
-        w_hat = mx.dequantize(w_q, scales, biases)
+        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
+            with self.subTest(w_shape=w_shape):
+                w = mx.random.normal(shape=(w_shape))
+                w_q, scales, biases = mx.quantize(w)
+                w_hat = mx.dequantize(w_q, scales, biases)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(3, 1, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm_t
-        x = mx.random.normal(shape=(10, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm_t
+                x = mx.random.normal(shape=(3, 10, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qvm
+                x = mx.random.normal(shape=(3, 1, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm
-        x = mx.random.normal(shape=(10, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm
+                x = mx.random.normal(shape=(3, 10, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))