OOB QMV fix (#1579)

* fix oob access in qmv

* skip more

* fix small case
This commit is contained in:
Alex Barron 2024-11-08 17:59:45 -08:00 committed by GitHub
parent 111fefd5e9
commit a4c47b0276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -564,8 +564,9 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, 0,
values_per_thread); values_per_thread);
U sum = if (remaining > 0) {
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining); U sum = load_vector_safe<T, U, values_per_thread, bits>(
x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = const device uint8_t* wl =
@ -575,7 +576,9 @@ METAL_FUNC void qmv_impl(
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
} }
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {