diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index f6d3671b43..30cf25cdfc 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -619,21 +619,22 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } } - for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) {