mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 02:36:42 +08:00
parent
c1fe1ef081
commit
111fefd5e9
@ -619,21 +619,22 @@ METAL_FUNC void qmv_impl(
|
||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||
0,
|
||||
values_per_thread);
|
||||
U sum =
|
||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
if (remaining > 0) {
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
||||
x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||
wl, x_thread, s, b, sum, remaining);
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||
wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
if (simd_lid == 0) {
|
||||
|
Loading…
Reference in New Issue
Block a user