diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 30cf25cdf..90e2dfaf1 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -564,18 +564,21 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; + for (int row = 0; out_row + row < out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } } for (int row = 0; out_row + row < out_vec_size; row++) {