mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
OOB QMV fix (#1579)
* fix oob access in qmv * skip more * fix small case
This commit is contained in:
parent
111fefd5e9
commit
a4c47b0276
@ -564,18 +564,21 @@ METAL_FUNC void qmv_impl(
|
|||||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||||
0,
|
0,
|
||||||
values_per_thread);
|
values_per_thread);
|
||||||
U sum =
|
if (remaining > 0) {
|
||||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
||||||
|
x, x_thread, remaining);
|
||||||
|
|
||||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||||
const device uint8_t* wl =
|
const device uint8_t* wl =
|
||||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = sl[0];
|
U s = sl[0];
|
||||||
U b = bl[0];
|
U b = bl[0];
|
||||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
result[row] +=
|
||||||
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||||
|
Loading…
Reference in New Issue
Block a user