From a4c47b0276b551c35cc67f0c7c403c0f890186a5 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 8 Nov 2024 17:59:45 -0800 Subject: [PATCH] OOB QMV fix (#1579) * fix oob access in qmv * skip more * fix small case --- mlx/backend/metal/kernels/quantized.h | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 30cf25cdf..90e2dfaf1 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -564,18 +564,21 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; + for (int row = 0; out_row + row < out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } } for (int row = 0; out_row + row < out_vec_size; row++) {