From c2e6d584411f65a86ed444aa2b12bbe23acfa27d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 16 Dec 2024 13:20:01 -0800 Subject: [PATCH] Revert the change in packing order --- mlx/backend/metal/kernels/quantized.h | 11 ++++++----- mlx/ops.cpp | 2 -- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index d52380aed..5e70fc8fe 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -2171,10 +2171,11 @@ inline vec partial_qdot_vec(const thread U* x, vec w) { else if (bits == 4) { for (int i = 0; i < 4; i++) { - auto ws = as_type>(w[i]); - for (int j = 0; j < 4; j++) { - accum[j] += - x[2 * i + 0] * (ws[j] & 0x0f) + x[2 * i + 1] * (ws[j] & 0xf0); + auto ws = as_type>(w[i]); + for (int j = 0; j < 2; j++) { + accum[i] += + (x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) + + x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000)); } } } @@ -2183,7 +2184,7 @@ inline vec partial_qdot_vec(const thread U* x, vec w) { for (int i = 0; i < 4; i++) { auto ws = as_type>(w[i]); for (int j = 0; j < 4; j++) { - accum[j] += x[i] * ws[j]; + accum[i] += x[j] * ws[j]; } } } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6f4911fec..7f6d296d1 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3795,11 +3795,9 @@ std::tuple> quantize( scales = moveaxis(scales, -2, -1, s); scales = flatten(scales, -2, -1, s); - wq = view(wq, uint8, s); wq = unflatten(wq, -2, {-1, 4}, s); wq = moveaxis(wq, -2, -1, s); wq = flatten(wq, -2, -1, s); - wq = view(wq, uint32, s); return std::make_tuple(wq, scales, std::nullopt); }